diff --git a/pkg/epp/flowcontrol/contracts/doc.go b/pkg/epp/flowcontrol/contracts/doc.go new file mode 100644 index 000000000..6e5b7834d --- /dev/null +++ b/pkg/epp/flowcontrol/contracts/doc.go @@ -0,0 +1,28 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package contracts defines the service interfaces that decouple the core `controller.FlowController` engine from its +// primary dependencies. In alignment with a "Ports and Adapters" (or "Hexagonal") architectural style, these +// interfaces represent the "ports" through which the engine communicates with external systems and pluggable logic. +// +// The two primary contracts defined here are: +// +// - `FlowRegistry`: The interface for the stateful control plane that manages the lifecycle of all flows, queues, and +// policies. +// +// - `SaturationDetector`: The interface for a component that provides real-time load signals, allowing the dispatch +// engine to react to backend saturation. +package contracts diff --git a/pkg/epp/flowcontrol/contracts/errors.go b/pkg/epp/flowcontrol/contracts/errors.go index fd46ec710..16acdbe28 100644 --- a/pkg/epp/flowcontrol/contracts/errors.go +++ b/pkg/epp/flowcontrol/contracts/errors.go @@ -20,11 +20,12 @@ import "errors" // Registry Errors var ( - // ErrFlowInstanceNotFound indicates that a requested flow instance (a `ManagedQueue`) does not exist in the registry - // shard, either because the flow is not registered or the specific instance (e.g., a draining queue at a particular - // priority) is not present. + // ErrFlowInstanceNotFound indicates that a requested flow instance (a `ManagedQueue`) does not exist. ErrFlowInstanceNotFound = errors.New("flow instance not found") + // ErrFlowIDEmpty indicates that a flow specification was provided with an empty flow ID. + ErrFlowIDEmpty = errors.New("flow ID cannot be empty") + // ErrPriorityBandNotFound indicates that a requested priority band does not exist in the registry because it was not // part of the initial configuration. ErrPriorityBandNotFound = errors.New("priority band not found") @@ -32,4 +33,7 @@ var ( // ErrPolicyQueueIncompatible indicates that a selected policy is not compatible with the capabilities of the queue it // is intended to operate on. For example, a policy requiring priority-based peeking is used with a simple FIFO queue. ErrPolicyQueueIncompatible = errors.New("policy is not compatible with queue capabilities") + + // ErrInvalidShardCount indicates that an invalid shard count was provided (e.g., zero). + ErrInvalidShardCount = errors.New("invalid shard count") ) diff --git a/pkg/epp/flowcontrol/contracts/registry.go b/pkg/epp/flowcontrol/contracts/registry.go index 843f501ff..7709f4747 100644 --- a/pkg/epp/flowcontrol/contracts/registry.go +++ b/pkg/epp/flowcontrol/contracts/registry.go @@ -14,18 +14,147 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -// Package contracts defines the service interfaces that decouple the core `controller.FlowController` engine from its -// primary dependencies. In alignment with a "Ports and Adapters" (or "Hexagonal") architectural style, these -// interfaces represent the "ports" through which the engine communicates. -// -// This package contains the primary service contracts for the Flow Registry, which acts as the control plane for all -// flow state and configuration. package contracts import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" ) +// FlowRegistry is the complete interface for the global control plane, composed of administrative functions and the +// ability to provide shard accessors. A concrete implementation of this interface is the single source of truth for all +// flow control state and configuration. +// +// # Conformance +// +// All methods defined in this interface (including those embedded) MUST be goroutine-safe. +// Implementations are expected to perform complex updates (e.g., `RegisterOrUpdateFlow`, `UpdateShardCount`) atomically +// to preserve system invariants. +// +// # Invariants +// +// Concrete implementations of FlowRegistry MUST uphold the following invariants across all operations: +// 1. Shard Consistency: All configured priority bands and logical flows must be represented on every Active internal +// shard. Plugin instance types (e.g., the specific `framework.SafeQueue` implementation or policy plugins) must be +// consistent for a given flow or band across all shards. +// 2. Flow Instance Uniqueness per Band: For any given logical flow, there can be a maximum of one `ManagedQueue` +// instance per priority band. An instance can be either Active or Draining. +// 3. Single Active Instance per Flow: For any given logical flow, there can be a maximum of one Active `ManagedQueue` +// instance across all priority bands. All other instances for that flow must be in a Draining state. +// 4. Capacity Partitioning Consistency: Global and per-band capacity limits are uniformly partitioned across all +// active shards. The sum of the capacity limits allocated to each shard must not exceed the globally configured +// limits. +// +// # Flow Lifecycle States +// +// - Registered: A logical flow is Registered when it is known to the `FlowRegistry`. It has exactly one Active +// instance across all priority bands and zero or more Draining instances. +// - Active: A specific instance of a flow within a priority band is Active if it is the designated target for all +// new enqueues for that logical flow. +// - Draining: A flow instance is Draining if it no longer accepts new enqueues but still contains items that are +// eligible for dispatch. This occurs after a priority change. +// - Garbage Collected (Unregistered): A logical flow is automatically unregistered and garbage collected by the +// system when it has been 'idle' for a configurable period. A flow is considered idle if its active queue instance +// has been empty on all active shards for the timeout duration. Once unregistered, it has no active instances, +// though draining instances from previous priority levels may still exist until their queues are also empty. +// +// # Shard Garbage Collection +// +// When a shard is decommissioned via `UpdateShardCount`, the `FlowRegistry` must ensure a graceful shutdown. 
It must +// mark the shard as inactive to prevent new enqueues, allow the `FlowController` to continue draining its queues, and +// only delete the shard's state after the associated worker has fully terminated and all queues are empty. +type FlowRegistry interface { + FlowRegistryAdmin + ShardProvider +} + +// FlowRegistryAdmin defines the administrative interface for the global control plane. This interface is intended for +// external systems to configure flows, manage system parallelism, and query aggregated statistics for observability. +// +// # Design Rationale for Dynamic Update Strategies +// +// The `FlowRegistryAdmin` contract specifies precise behaviors for handling dynamic updates. These strategies were +// chosen to prioritize system stability, correctness, and minimal disruption: +// +// - Graceful Draining (for Priority/Shard Lifecycle Changes): For operations that change a flow's priority or +// decommission a shard, the affected queue instances are marked as inactive but are not immediately deleted. They +// enter a Draining state where they no longer accept new requests but are still processed for dispatch. This +// ensures that requests already accepted by the system are processed to completion. Crucially, requests in a +// draining queue continue to be dispatched according to the priority level and policies they were enqueued with, +// ensuring consistency. +// +// - Atomic Queue Migration (Future Design for Incompatible Intra-Flow Policy Changes): When an intra-flow policy is +// updated to one that is incompatible with the existing queue data structure, the designed future behavior is a +// full "drain and re-enqueue" migration. This more disruptive operation is necessary to guarantee correctness. A +// simpler "graceful drain"—by creating a second instance of the same flow in the same priority band—is not used +// because it would violate the system's "one flow instance per band" invariant. This invariant is critical because +// it ensures that inter-flow policies operate on a clean set of distinct flows, stateful intra-flow policies have a +// single authoritative view of their flow's state, and lookups are unambiguous. Note: This atomic migration is a +// future design consideration and is not implemented in the current version. +// +// - Self-Balancing on Shard Scale-Up: When new shards are added via `UpdateShardCount`, the framework relies on the +// `FlowController`'s request distribution logic (e.g., a "Join the Shortest Queue by Bytes (JSQ-Bytes)" strategy) +// to naturally funnel *new* requests to the less-loaded shards. This design choice strategically avoids the +// complexity of actively migrating or rebalancing existing items that are already queued on other shards, promoting +// system stability during scaling events. +type FlowRegistryAdmin interface { + // RegisterOrUpdateFlow handles the registration of a new flow or the update of an existing flow's specification. + // This method orchestrates complex state transitions atomically across all managed shards. + // + // # Dynamic Update Behaviors + // + // - Priority Changes: If a flow's priority level changes, its current active `ManagedQueue` instance is marked + // as inactive to drain existing requests. A new instance is activated at the new priority level. If a flow is + // updated to a priority level where an instance is already draining (e.g., during a rapid rollback), that + // draining instance is re-activated. + // + // # Returns + // + // - nil on success. 
+ // - An error wrapping `ErrFlowIDEmpty` if `spec.ID` is empty. + // - An error wrapping`ErrPriorityBandNotFound` if `spec.Priority` refers to an unconfigured priority level. + // - Other errors if internal creation/activation of policy or queue instances fail. + RegisterOrUpdateFlow(spec types.FlowSpecification) error + + // UpdateShardCount dynamically adjusts the number of internal state shards, triggering a state rebalance. + // + // # Dynamic Update Behaviors + // + // - On Increase: New, empty state shards are initialized with all registered flows. The + // `controller.FlowController`'s request distribution logic will naturally balance load to these new shards over + // time. + // - On Decrease: A specified number of existing shards are marked as inactive. They stop accepting new requests + // but continue to drain existing items. They are fully removed only after their queues are empty. + // + // The implementation MUST atomically re-partition capacity allocations across all active shards when the count + // changes. + UpdateShardCount(n uint) error + + // Stats returns globally aggregated statistics for the entire `FlowRegistry`. + Stats() AggregateStats + + // ShardStats returns a slice of statistics, one for each internal shard. This provides visibility for debugging and + // monitoring per-shard behavior (e.g., identifying hot or stuck shards). + ShardStats() []ShardStats +} + +// ShardProvider defines a minimal interface for consumers that need to discover and iterate over available shards. +// +// A "shard" is an internal, parallel execution unit that allows the `FlowController`'s core dispatch logic to be +// parallelized. Consumers of this interface, such as a request distributor, MUST check `RegistryShard.IsActive()` +// before routing new work to a shard to ensure they do not send requests to a shard that is gracefully draining. +type ShardProvider interface { + // Shards returns a slice of accessors, one for each internal state shard. + // + // A "shard" is an internal, parallel execution unit that allows the `controller.FlowController`'s core dispatch logic + // to be parallelized, preventing a CPU bottleneck at high request rates. The `FlowRegistry`'s state is sharded to + // support this parallelism by reducing lock contention. + // + // The returned slice includes accessors for both active and draining shards. Consumers MUST use `IsActive()` to + // determine if new work should be routed to a shard. Callers should not modify the returned slice. + Shards() []RegistryShard +} + // RegistryShard defines the read-oriented interface that a `controller.FlowController` worker uses to access its // specific slice (shard) of the `FlowRegistry`'s state. It provides the necessary methods for a worker to perform its // dispatch operations by accessing queues and policies in a concurrent-safe manner. @@ -80,14 +209,14 @@ type RegistryShard interface { } // ManagedQueue defines the interface for a flow's queue instance on a specific shard. -// It wraps an underlying `framework.SafeQueue`, augmenting it with lifecycle validation against the `FlowRegistry` and -// integrating atomic statistics updates. +// It acts as a stateful decorator around an underlying `framework.SafeQueue`, augmenting it with lifecycle validation +// against the `FlowRegistry` and integrating atomic statistics updates. // // # Conformance // -// - All methods (including those embedded from `framework.SafeQueue`) MUST be goroutine-safe. 
-// - The `Add()` method MUST reject new items if the queue has been marked as "draining" by the `FlowRegistry`, -// ensuring that lifecycle changes are respected even by consumers holding a stale pointer to the queue. +// - All methods defined by this interface and the `framework.SafeQueue` it wraps MUST be goroutine-safe. +// - The `Add()` method MUST reject new items if the queue has been marked as Draining by the `FlowRegistry`, ensuring +// that lifecycle changes are respected even by consumers holding a stale pointer to the queue. // - All mutating methods (`Add()`, `Remove()`, `Cleanup()`, `Drain()`) MUST atomically update relevant statistics // (e.g., queue length, byte size). type ManagedQueue interface { @@ -100,6 +229,18 @@ type ManagedQueue interface { FlowQueueAccessor() framework.FlowQueueAccessor } +// AggregateStats holds globally aggregated statistics for the entire `FlowRegistry`. +type AggregateStats struct { + // TotalCapacityBytes is the globally configured maximum total byte size limit across all priority bands and shards. + TotalCapacityBytes uint64 + // TotalByteSize is the total byte size of all items currently queued across the entire system. + TotalByteSize uint64 + // TotalLen is the total number of items currently queued across the entire system. + TotalLen uint64 + // PerPriorityBandStats maps each configured priority level to its globally aggregated statistics. + PerPriorityBandStats map[uint]PriorityBandStats +} + // ShardStats holds statistics for a single internal shard within the `FlowRegistry`. type ShardStats struct { // TotalCapacityBytes is the optional, maximum total byte size limit aggregated across all priority bands within this @@ -112,6 +253,7 @@ type ShardStats struct { // TotalLen is the total number of items currently queued across all priority bands within this shard. TotalLen uint64 // PerPriorityBandStats maps each configured priority level to its statistics within this shard. + // The capacity values within represent this shard's partition of the global band capacity. // The key is the numerical priority level. // All configured priority levels are guaranteed to be represented. PerPriorityBandStats map[uint]PriorityBandStats @@ -138,9 +280,9 @@ type PriorityBandStats struct { Priority uint // PriorityName is an optional, human-readable name for the priority level (e.g., "Critical", "Sheddable"). PriorityName string - // CapacityBytes is the configured maximum total byte size for this priority band, aggregated across all items in - // all flow queues within this band. If scoped to a shard, its value represents the configured band limit for the - // `FlowRegistry` partitioned for this shard. + // CapacityBytes is the configured maximum total byte size for this priority band. + // When viewed via `AggregateStats`, this is the global limit. When viewed via `ShardStats`, this is the partitioned + // value for that specific shard. // The `controller.FlowController` enforces this limit. // A default non-zero value is guaranteed if not configured. 
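+	// For example (illustrative numbers): a band configured with a 100 MB global limit in a registry with
+	// 4 shards reports 25 MB here in each shard's `ShardStats`.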
CapacityBytes uint64 diff --git a/pkg/epp/flowcontrol/registry/config.go b/pkg/epp/flowcontrol/registry/config.go index ed98ab830..7b09a484d 100644 --- a/pkg/epp/flowcontrol/registry/config.go +++ b/pkg/epp/flowcontrol/registry/config.go @@ -19,6 +19,7 @@ package registry import ( "errors" "fmt" + "time" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework" @@ -30,6 +31,22 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/queue/listqueue" ) +const ( + // defaultPriorityBandMaxBytes is the default capacity for a priority band if not explicitly configured. + // It is set to 1 GB. + defaultPriorityBandMaxBytes = 1_000_000_000 + // defaultIntraFlowDispatchPolicy is the default policy for selecting items within a single flow's queue. + defaultIntraFlowDispatchPolicy = fcfs.FCFSPolicyName + // defaultInterFlowDispatchPolicy is the default policy for selecting which flow's queue to service next. + defaultInterFlowDispatchPolicy = besthead.BestHeadPolicyName + // defaultQueue is the default queue implementation for flows. + defaultQueue = listqueue.ListQueueName + // defaultFlowGCTimeout is the default duration of inactivity after which an idle flow is garbage collected. + defaultFlowGCTimeout = 5 * time.Minute + // defaultEventChannelBufferSize is the default size of the buffered channel for control plane events. + defaultEventChannelBufferSize = 4096 +) + // Config holds the master configuration for the entire `FlowRegistry`. It serves as the top-level blueprint, defining // global capacity limits and the structure of its priority bands. // @@ -47,50 +64,93 @@ type Config struct { // // Required: At least one `PriorityBandConfig` must be provided for a functional registry. PriorityBands []PriorityBandConfig + + // FlowGCTimeout defines the duration of inactivity after which an idle flow is automatically unregistered and + // garbage collected. A flow is considered inactive if its queues have been empty across all shards for this + // duration. + // + // Optional: Defaults to 5 minutes. + FlowGCTimeout time.Duration + + // EventChannelBufferSize defines the size of the buffered channel used for internal control plane events. + // A larger buffer can absorb larger bursts of events (e.g., from many queues becoming non-empty simultaneously) + // without blocking the data path, but consumes more memory. + // + // Optional: Defaults to 4096. + EventChannelBufferSize uint32 + + // priorityBandMap is a cache for O(1) lookups of PriorityBandConfig by priority level. + // It is populated during validateAndApplyDefaults and when the config is partitioned or copied. + priorityBandMap map[uint]*PriorityBandConfig } -// partition calculates and returns a new `Config` with capacity values partitioned for a specific shard. -// This method ensures that the total capacity is distributed as evenly as possible across all shards. +// partition calculates and returns a new `Config` with capacity values partitioned for a specific shard. This method +// ensures that the total capacity is distributed as evenly as possible across all shards by distributing the remainder +// of the division one by one to the first few shards. 
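+//
+// A minimal sketch of the intended call pattern (hypothetical caller and variable names; the real call
+// site lives in the registry's shard-management logic):
+//
+//	shardConfigs := make([]*Config, totalShards)
+//	for i := range shardConfigs {
+//		cfg, err := globalCfg.partition(i, totalShards)
+//		if err != nil {
+//			return err
+//		}
+//		shardConfigs[i] = cfg // each shard receives its slice of the global and per-band limits
+//	}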
func (c *Config) partition(shardIndex, totalShards int) (*Config, error) { if totalShards <= 0 || shardIndex < 0 || shardIndex >= totalShards { return nil, fmt.Errorf("invalid shard partitioning arguments: shardIndex=%d, totalShards=%d", shardIndex, totalShards) } - partitionValue := func(total uint64) uint64 { - if total == 0 { - return 0 - } - base := total / uint64(totalShards) - remainder := total % uint64(totalShards) - if uint64(shardIndex) < remainder { - return base + 1 - } - return base - } - newCfg := &Config{ - MaxBytes: partitionValue(c.MaxBytes), - PriorityBands: make([]PriorityBandConfig, len(c.PriorityBands)), + MaxBytes: partitionUint64(c.MaxBytes, shardIndex, totalShards), + FlowGCTimeout: c.FlowGCTimeout, + EventChannelBufferSize: c.EventChannelBufferSize, + PriorityBands: make([]PriorityBandConfig, len(c.PriorityBands)), + priorityBandMap: make(map[uint]*PriorityBandConfig, len(c.PriorityBands)), } for i, band := range c.PriorityBands { - newBand := band // Copy the original config - newBand.MaxBytes = partitionValue(band.MaxBytes) // Overwrite with the partitioned value + newBand := band // Copy the original config + newBand.MaxBytes = partitionUint64(band.MaxBytes, shardIndex, totalShards) // Overwrite with the partitioned value newCfg.PriorityBands[i] = newBand } + // Populate the map for the new partitioned config. + for i := range newCfg.PriorityBands { + band := &newCfg.PriorityBands[i] + newCfg.priorityBandMap[band.Priority] = band + } + return newCfg, nil } -// validateAndApplyDefaults checks the configuration for validity and populates any empty fields with system defaults. -// This method should be called once by the registry before it initializes any shards. +// partitionUint64 distributes a total uint64 value across a number of partitions. +// It distributes the remainder of the division one by one to the first few partitions. +func partitionUint64(total uint64, partitionIndex, totalPartitions int) uint64 { + if total == 0 { + return 0 + } + base := total / uint64(totalPartitions) + remainder := total % uint64(totalPartitions) + // Distribute the remainder. The first `remainder` partitions get one extra from the total. + // For example, if total=10 and partitions=3, base=3, remainder=1. Partition 0 gets 3+1=4, partitions 1 and 2 get 3. + if uint64(partitionIndex) < remainder { + base++ + } + return base +} + +// validateAndApplyDefaults checks the configuration for validity and mutates it to populate any empty fields with +// system defaults. It ensures that all priority bands are well-formed, have unique priority levels and names, and that +// their chosen plugins are compatible. This method should be called once by the registry before it initializes any +// shards. 
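+//
+// For example (illustrative only), a minimal config naming a single band:
+//
+//	cfg := &Config{PriorityBands: []PriorityBandConfig{{Priority: 1, PriorityName: "Critical"}}}
+//	err := cfg.validateAndApplyDefaults()
+//
+// succeeds and leaves the band populated with the documented defaults: the "FCFS" intra-flow policy, the
+// "BestHead" inter-flow policy, the "ListQueue" queue, and the default per-band MaxBytes, plus the default
+// FlowGCTimeout and EventChannelBufferSize at the top level.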
func (c *Config) validateAndApplyDefaults() error { + if c.FlowGCTimeout <= 0 { + c.FlowGCTimeout = defaultFlowGCTimeout + } + if c.EventChannelBufferSize <= 0 { + c.EventChannelBufferSize = defaultEventChannelBufferSize + } + if len(c.PriorityBands) == 0 { return errors.New("config validation failed: at least one priority band must be defined") } - priorities := make(map[uint]struct{}) // Keep track of seen priorities + priorities := make(map[uint]struct{}) + priorityNames := make(map[string]struct{}) + c.priorityBandMap = make(map[uint]*PriorityBandConfig, len(c.PriorityBands)) for i := range c.PriorityBands { band := &c.PriorityBands[i] @@ -100,27 +160,39 @@ func (c *Config) validateAndApplyDefaults() error { priorities[band.Priority] = struct{}{} if band.PriorityName == "" { - return errors.New("config validation failed: PriorityName is required for all priority bands") + return fmt.Errorf("config validation failed: PriorityName is required for priority band %d", band.Priority) + } + if _, exists := priorityNames[band.PriorityName]; exists { + return fmt.Errorf("config validation failed: duplicate priority name %q found", band.PriorityName) } + priorityNames[band.PriorityName] = struct{}{} + if band.IntraFlowDispatchPolicy == "" { - band.IntraFlowDispatchPolicy = fcfs.FCFSPolicyName + band.IntraFlowDispatchPolicy = defaultIntraFlowDispatchPolicy } if band.InterFlowDispatchPolicy == "" { - band.InterFlowDispatchPolicy = besthead.BestHeadPolicyName + band.InterFlowDispatchPolicy = defaultInterFlowDispatchPolicy } if band.Queue == "" { - band.Queue = listqueue.ListQueueName + band.Queue = defaultQueue + } + if band.MaxBytes == 0 { + band.MaxBytes = defaultPriorityBandMaxBytes } // After defaulting, validate that the chosen plugins are compatible. if err := validateBandCompatibility(*band); err != nil { return err } + // Populate the lookup map. + c.priorityBandMap[band.Priority] = band } return nil } -// validateBandCompatibility verifies that a band's default policy is compatible with its default queue type. +// validateBandCompatibility verifies that a band's configured queue type has the necessary capabilities to support its +// configured intra-flow dispatch policy. For example, a priority-based policy requires a queue that supports priority +// ordering. func validateBandCompatibility(band PriorityBandConfig) error { policy, err := intra.NewPolicyFromName(band.IntraFlowDispatchPolicy) if err != nil { @@ -182,19 +254,19 @@ type PriorityBandConfig struct { // IntraFlowDispatchPolicy specifies the default name of the registered policy used to select a specific request to // dispatch next from within a single flow's queue in this band. This default can be overridden on a per-flow basis. // - // Optional: If empty, a system default (e.g., "FCFS") is used. + // Optional: If empty, a system default ("FCFS") is used. IntraFlowDispatchPolicy intra.RegisteredPolicyName // InterFlowDispatchPolicy specifies the name of the registered policy used to select which flow's queue to service // next from this band. // - // Optional: If empty, a system default (e.g., "BestHead") is used. + // Optional: If empty, a system default ("BestHead") is used. InterFlowDispatchPolicy inter.RegisteredPolicyName // Queue specifies the default name of the registered SafeQueue implementation to be used for flow queues within this // band. // - // Optional: If empty, a system default (e.g., "ListQueue") is used. + // Optional: If empty, a system default ("ListQueue") is used. 
Queue queue.RegisteredQueueName // MaxBytes defines the maximum total byte size for this specific priority band, aggregated across all shards. @@ -202,3 +274,36 @@ type PriorityBandConfig struct { // Optional: If not set, a system default (e.g., 1 GB) is applied. MaxBytes uint64 } + +// getBandConfig finds and returns the configuration for a specific priority level using the O(1) lookup map. +func (c *Config) getBandConfig(priority uint) (*PriorityBandConfig, error) { + if band, ok := c.priorityBandMap[priority]; ok { + return band, nil + } + return nil, fmt.Errorf("config for priority %d not found: %w", priority, contracts.ErrPriorityBandNotFound) +} + +// deepCopy creates a deep copy of the Config object. +func (c *Config) deepCopy() *Config { + if c == nil { + return nil + } + newCfg := &Config{ + MaxBytes: c.MaxBytes, + FlowGCTimeout: c.FlowGCTimeout, + EventChannelBufferSize: c.EventChannelBufferSize, + PriorityBands: make([]PriorityBandConfig, len(c.PriorityBands)), + priorityBandMap: make(map[uint]*PriorityBandConfig, len(c.PriorityBands)), + } + // PriorityBandConfig is a struct of value types, so a direct copy of the struct + // is sufficient for a deep copy. The `copy` built-in creates a new slice and + // copies the struct values from the original slice into it. + copy(newCfg.PriorityBands, c.PriorityBands) + + // Rebuild the map so pointers refer to the new slice elements. + for i := range newCfg.PriorityBands { + band := &newCfg.PriorityBands[i] + newCfg.priorityBandMap[band.Priority] = band + } + return newCfg +} diff --git a/pkg/epp/flowcontrol/registry/config_test.go b/pkg/epp/flowcontrol/registry/config_test.go index 0cf9f93d8..e65e23f3c 100644 --- a/pkg/epp/flowcontrol/registry/config_test.go +++ b/pkg/epp/flowcontrol/registry/config_test.go @@ -19,6 +19,7 @@ package registry import ( "errors" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -35,8 +36,6 @@ import ( ) func TestConfig_ValidateAndApplyDefaults(t *testing.T) { - t.Parallel() - // Setup for failure injection tests failingPolicyName := intra.RegisteredPolicyName("failing-policy-for-config-test") intra.MustRegisterPolicy(failingPolicyName, func() (framework.IntraFlowDispatchPolicy, error) { @@ -76,20 +75,24 @@ func TestConfig_ValidateAndApplyDefaults(t *testing.T) { }, expectErr: false, expectedCfg: &Config{ + FlowGCTimeout: defaultFlowGCTimeout, + EventChannelBufferSize: defaultEventChannelBufferSize, PriorityBands: []PriorityBandConfig{ { Priority: 1, PriorityName: "High", - IntraFlowDispatchPolicy: fcfs.FCFSPolicyName, - InterFlowDispatchPolicy: besthead.BestHeadPolicyName, - Queue: listqueue.ListQueueName, + IntraFlowDispatchPolicy: defaultIntraFlowDispatchPolicy, + InterFlowDispatchPolicy: defaultInterFlowDispatchPolicy, + Queue: defaultQueue, + MaxBytes: defaultPriorityBandMaxBytes, }, { Priority: 2, PriorityName: "Low", - IntraFlowDispatchPolicy: fcfs.FCFSPolicyName, + IntraFlowDispatchPolicy: defaultIntraFlowDispatchPolicy, InterFlowDispatchPolicy: roundrobin.RoundRobinPolicyName, - Queue: listqueue.ListQueueName, + Queue: defaultQueue, + MaxBytes: defaultPriorityBandMaxBytes, }, }, }, @@ -97,7 +100,9 @@ func TestConfig_ValidateAndApplyDefaults(t *testing.T) { { name: "Config with all fields specified and compatible", input: &Config{ - MaxBytes: 1000, + MaxBytes: 1000, + FlowGCTimeout: 10 * time.Minute, + EventChannelBufferSize: 5000, PriorityBands: []PriorityBandConfig{ { Priority: 1, @@ -111,7 +116,9 @@ func TestConfig_ValidateAndApplyDefaults(t 
*testing.T) { }, expectErr: false, expectedCfg: &Config{ // Should be unchanged - MaxBytes: 1000, + MaxBytes: 1000, + FlowGCTimeout: 10 * time.Minute, + EventChannelBufferSize: 5000, PriorityBands: []PriorityBandConfig{ { Priority: 1, @@ -148,6 +155,16 @@ func TestConfig_ValidateAndApplyDefaults(t *testing.T) { }, expectErr: true, }, + { + name: "Error: Duplicate priority name", + input: &Config{ + PriorityBands: []PriorityBandConfig{ + {Priority: 1, PriorityName: "High"}, + {Priority: 2, PriorityName: "High"}, + }, + }, + expectErr: true, + }, { name: "Error: Incompatible policy and queue", input: &Config{ @@ -191,20 +208,37 @@ func TestConfig_ValidateAndApplyDefaults(t *testing.T) { }, expectErr: true, }, + { + name: "Error: Non-existent queue name", + input: &Config{ + PriorityBands: []PriorityBandConfig{ + { + Priority: 1, + PriorityName: "High", + Queue: "non-existent-queue", + }, + }, + }, + expectErr: true, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() - err := tc.input.validateAndApplyDefaults() + // Create a deep copy to prevent data races between parallel tests. + configCopy := tc.input.deepCopy() + expectedCfgCopy := tc.expectedCfg.deepCopy() + + err := configCopy.validateAndApplyDefaults() if tc.expectErr { - require.Error(t, err, "Expected an error for this test case") + require.Error(t, err, "validateAndApplyDefaults should have returned an error") if tc.expectedErrIs != nil { - assert.ErrorIs(t, err, tc.expectedErrIs, "Error should be of the expected type") + assert.ErrorIs(t, err, tc.expectedErrIs) } } else { - require.NoError(t, err, "Did not expect an error for this test case") - assert.Equal(t, tc.expectedCfg, tc.input, "Config after applying defaults does not match expected config") + require.NoError(t, err, "validateAndApplyDefaults should not have returned an error") + assert.Equal(t, expectedCfgCopy, configCopy, "Config should have been correctly defaulted") } }) } diff --git a/pkg/epp/flowcontrol/registry/doc.go b/pkg/epp/flowcontrol/registry/doc.go index 521ad6e7b..be9fe551e 100644 --- a/pkg/epp/flowcontrol/registry/doc.go +++ b/pkg/epp/flowcontrol/registry/doc.go @@ -14,27 +14,122 @@ See the License for the specific language governing permissions and limitations under the License. */ -// Package registry provides the concrete implementation of the Flow Registry. +// Package registry provides the concrete implementation of the `contracts.FlowRegistry`. // // As the stateful control plane for the entire Flow Control system, this package is responsible for managing the -// lifecycle of all flows, queues, and policies. It serves as the "adapter" that implements the service "ports" -// (interfaces) defined in the `contracts` package. It provides a sharded, concurrent-safe view of its state to the +// lifecycle of all flows, queues, and policies. It provides a sharded, concurrent-safe view of its state to the // `controller.FlowController` workers, enabling scalable, parallel request processing. // -// # Key Components +// # Architecture: Composite, Sharded, and Separated Concerns // -// - `FlowRegistry`: The future top-level administrative object that will manage the entire system, including shard -// lifecycles and flow registration. (Not yet implemented). +// The registry employs a composite architecture designed to separate the control plane (orchestration) from the data +// plane (request processing state). // -// - `registryShard`: A concrete implementation of the `contracts.RegistryShard` interface. 
It represents a single, -// concurrent-safe slice of the registry's state, containing a set of priority bands and the flow queues within -// them. This is the primary object a `controller.FlowController` worker interacts with. +// - `FlowRegistry`: The Control Plane. The top-level orchestrator and single source of truth. It manages the +// lifecycle of its child shards and centralizes all complex administrative operations, such as flow registration, +// garbage collection coordination, and shard scaling. // -// - `managedQueue`: A concrete implementation of the `contracts.ManagedQueue` interface. It acts as a stateful -// decorator around a `framework.SafeQueue`, adding critical registry-level functionality such as atomic statistics -// tracking and lifecycle state enforcement (active vs. draining). +// - `registryShard`: The Data Plane Slice. A single, concurrent-safe "slice" of the registry's total state. It is the +// primary object that a `controller.FlowController` worker interacts with, providing a simplified, read-optimized +// view of the queues and policies it needs to operate. // -// - `Config`: The top-level configuration object that defines the structure and default behaviors of the registry, -// including the definition of priority bands and default policy selections. This configuration is partitioned and -// distributed to each `registryShard`. +// - `managedQueue`: The Stateful Decorator. A wrapper around a generic `framework.SafeQueue`. It augments the queue +// with registry-specific concerns: atomic statistics tracking and lifecycle enforcement (active vs. draining). +// +// # Concurrency Model: Multi-Tiered Locking and the Actor Pattern +// +// The registry employs a multi-tiered concurrency strategy to maximize performance on the hot path while ensuring +// strict correctness for complex state transitions: +// +// 1. Serialized Control Plane (Actor Model): The `FlowRegistry` utilizes an Actor-like pattern for its control plane. +// A single background goroutine processes all state change events (e.g., GC timers, queue emptiness signals) from a +// channel. This serializes all mutations to the registry's core state, eliminating a significant class of race +// conditions and simplifying the complex logic of distributed garbage collection and flow lifecycles. +// +// 2. Coarse-Grained Admin Lock: A single `sync.Mutex` on the `FlowRegistry` protects administrative operations (e.g., +// `RegisterOrUpdateFlow`). This lock is acquired both by external callers and the internal event loop, ensuring +// that these complex, multi-step state changes appear atomic. +// +// 3. Shard-Level R/W Lock: Each `registryShard` uses a `sync.RWMutex` to protect its internal maps. This allows +// multiple `controller.FlowController` workers to read from the same shard in parallel. +// +// 4. Lock-Free Data Path (Atomics): All statistics (queue length, byte size) at all layers are implemented using +// `atomic.Uint64`. This allows for high-performance, lock-free updates on the "data plane" hot path, where every +// request modifies these counters. +// +// # The "Trust but Verify" Garbage Collection Pattern +// +// A critical aspect of the registry's design is the race condition between the asynchronous data path (e.g., a queue +// becoming non-empty) and the control plane's destructive operations (garbage collection). The `flowState` object in +// the control plane is an eventually consistent cache of the system's state. 
Relying on this cached view for a +// destructive action could lead to incorrect behavior, for instance, if a GC timer fires and is processed before the +// activity event that should have cancelled it. +// +// To solve this without sacrificing performance via synchronous locking on the hot path, the registry employs a +// "Trust but Verify" pattern for all garbage collection decisions: +// +// 1. Trust: The control plane first "trusts" its cached `flowState` to make a preliminary, non-destructive decision +// (e.g., "the flow appears to be idle"). +// +// 2. Verify: Before committing to the destructive action (deleting the flow), the control plane performs a "verify" +// step. It synchronously queries the ground truth—the atomic counters on the live `managedQueue` instances across +// all relevant shards. +// +// This pattern provides the necessary strong consistency for critical operations precisely when needed, incurring the +// overhead of the live check only during the GC process (which is off the hot path), thereby maintaining high +// performance on the request path. +// +// # Event-Driven State Machine and Dynamic Updates +// +// The registry is designed to handle dynamic configuration changes gracefully and correctly. The interplay between +// state transitions, the event-driven control plane, and the garbage collection (GC) system is critical to its +// robustness. +// +// The system relies on atomic state transitions to generate reliable, edge-triggered signals. Components (queues, +// shards) use atomic state transitions (e.g., transitioning from Draining to Drained) to signal the control plane +// exactly once when a critical event occurs. These signals are sent reliably over the event channel; if the channel is +// full, the sender blocks, applying necessary backpressure to ensure no events are lost, which is vital for preventing +// state divergence and memory leaks. +// +// The following scenarios detail how the registry handles various lifecycle events: +// +// New Flow Registration: A new flow `F` is registered at priority `P1`. +// +// 1. A new `managedQueue` (`Q1`) is created on all shards and marked Active. +// 2. The control plane (`FlowRegistry`) starts inactivity GC tracking. If `Q1` remains empty globally for +// `FlowGCTimeout`, flow `F` is automatically unregistered. +// +// Flow Activity/Inactivity: A flow transitions between having requests and being empty. +// +// 1. When the first request is enqueued, `Q1` signals `QueueBecameNonEmpty`. The control plane stops the GC timer. +// 2. When the last request is dispatched globally, `Q1` signals `QueueBecameEmpty`. The control plane starts the GC +// timer. +// +// Flow Priority Change: Flow `F` changes from priority `P1` to `P2`. +// +// 1. The existing queue (`Q1`) at `P1` transitions to Draining (stops accepting new requests). +// 2. A new queue (`Q2`) at `P2` is created and marked Active. +// 3. Inactivity GC tracking starts for `Q2`. +// 4. When `Q1` becomes empty globally, it transitions to Drained and signals `QueueBecameDrainedAndEmpty`. The control +// plane garbage collects Q1 instances. +// +// Draining Reactivation: Flow `F` changes `P1` -> `P2`, then quickly back `P2` -> `P1` before `Q1` is empty. +// +// 1. (`P1`->`P2`): `Q1` is Draining, `Q2` is Active. +// 2. (`P2`->`P1`): The system finds `Q1` and atomically transitions it back to Active. `Q2` transitions to Draining. +// 3. This optimization avoids unnecessary object churn. GC tracking is correctly updated. 
+// +// Shard Scale-Up: New shards are added. +// +// 1. New shards are created. +// 2. The `FlowRegistry` synchronizes all existing flows onto the new shards. +// 3. GC tracking state is initialized for these new queue instances. +// 4. Configuration (e.g., capacity limits) is re-partitioned across all active shards. +// +// Shard Scale-Down: The shard count is reduced. +// +// 1. Targeted shards transition to Draining. +// 2. Configuration is re-partitioned across the remaining active shards. +// 3. When a draining shard becomes completely empty, it transitions to Drained and signals `ShardDrainedAndEmpty`. +// 4. The control plane removes the shard and purges its ID from all flow tracking maps to prevent memory leaks. package registry diff --git a/pkg/epp/flowcontrol/registry/flowstate.go b/pkg/epp/flowcontrol/registry/flowstate.go new file mode 100644 index 000000000..7ebd65d52 --- /dev/null +++ b/pkg/epp/flowcontrol/registry/flowstate.go @@ -0,0 +1,150 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package registry + +import "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" + +// flowState holds all tracking state for a single logical flow within the registry. +// +// # Role: The Eventually Consistent Cache for GC +// +// This structure is central to the garbage collection (GC) logic. It acts as an eventually consistent, cached view of a +// flow's emptiness status across all shards. It is updated asynchronously via events from the data path and is used by +// the `FlowRegistry`'s control plane to make preliminary, non-destructive decisions about a flow's lifecycle (e.g., +// "the flow appears to be idle, start a GC timer"). +// +// # Concurrency and Consistency Model +// +// `flowState` is a passive, non-thread-safe data structure. It is owned and managed exclusively by the `FlowRegistry`'s +// single-threaded event loop. +// +// CRITICAL: Because this state is eventually consistent, it MUST NOT be the sole source of truth for any destructive +// operation (like garbage collection). All destructive actions must first verify the live state of the system by +// consulting the atomic counters on the `managedQueue` instances directly. This is part of the "Trust but Verify" +// pattern. +type flowState struct { + // spec is the desired state of the flow. + spec types.FlowSpecification + // generation is an internal counter to resolve races between GC timers and flow re-registration. + // This ensures that a stale timer event from a previous configuration does not incorrectly garbage collect the flow. + generation uint64 + // activeQueueEmptyOnShards tracks the empty status of the flow's single active queue across all shards. + // The key is the shard ID. + activeQueueEmptyOnShards map[string]bool + // drainingQueuesEmptyOnShards tracks the empty status of any draining queues for this flow. + // The key is the priority level of the draining queue, then the shard ID. 
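+	// For example (illustrative values): {10: {"shard-a": true, "shard-b": false}} means the flow's former
+	// priority-10 queue has fully drained on "shard-a" but still holds items on "shard-b".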
+ drainingQueuesEmptyOnShards map[uint]map[string]bool +} + +// newFlowState creates the initial state for a newly registered flow. +// It initializes the state based on the current set of active and draining shards. +func newFlowState(spec types.FlowSpecification, allShards []*registryShard) *flowState { + s := &flowState{ + spec: spec, + generation: 1, + activeQueueEmptyOnShards: make(map[string]bool, len(allShards)), + } + // New queues start empty on all shards (active and draining). + for _, shard := range allShards { + s.activeQueueEmptyOnShards[shard.id] = true + } + return s +} + +// update applies a new specification to the flow, incrementing its generation and handling priority change logic. +func (s *flowState) update(spec types.FlowSpecification, allShards []*registryShard) { + oldPriority := s.spec.Priority + s.spec = spec + s.generation++ // Invalidate any pending GC timers for the old generation. + + // If priority did not change, there's nothing more to do. + if oldPriority == spec.Priority { + return + } + + // Priority changed. The old active queue is now a draining queue. + // We transfer its current emptiness state from the active map to the draining map. + if s.drainingQueuesEmptyOnShards == nil { + s.drainingQueuesEmptyOnShards = make(map[uint]map[string]bool) + } + s.drainingQueuesEmptyOnShards[oldPriority] = s.activeQueueEmptyOnShards + + // After an update, the new active queue must be re-evaluated. + // Check if we are reactivating a previously draining queue. + if drainingState, ok := s.drainingQueuesEmptyOnShards[spec.Priority]; ok { + // Yes, we are reactivating. The draining state becomes the new active state. + s.activeQueueEmptyOnShards = drainingState + delete(s.drainingQueuesEmptyOnShards, spec.Priority) + } else { + // No, this is a new priority. It starts as empty on all shards. + s.activeQueueEmptyOnShards = make(map[string]bool, len(allShards)) + for _, shard := range allShards { + s.activeQueueEmptyOnShards[shard.id] = true + } + } +} + +// handleQueueSignal updates the flow's internal state based on a signal from one of its queues. +func (s *flowState) handleQueueSignal(shardID string, priority uint, signal queueStateSignal) { + switch signal { + case queueStateSignalBecameDrained: + if priorityState, ok := s.drainingQueuesEmptyOnShards[priority]; ok { + priorityState[shardID] = true + } + case queueStateSignalBecameEmpty: + s.activeQueueEmptyOnShards[shardID] = true + case queueStateSignalBecameNonEmpty: + s.activeQueueEmptyOnShards[shardID] = false + } +} + +// isIdle checks if the flow's active queues are empty across all active shards. +// A flow is considered idle even if it has items remaining in a shard that is itself draining. +func (s *flowState) isIdle(activeShards []*registryShard) bool { + for _, shard := range activeShards { + // We rely on the caller (FlowRegistry) to provide only the active shards. + if !s.activeQueueEmptyOnShards[shard.id] { + return false + } + } + return true +} + +// isDrained checks if a specific draining queue is now empty on all shards (active and draining). +func (s *flowState) isDrained(priority uint, allShards []*registryShard) bool { + priorityState, ok := s.drainingQueuesEmptyOnShards[priority] + if !ok { + // If the priority isn't in the map, it's not currently draining, so it cannot be complete. + return false + } + + // We must check both active and draining shards, as the queue instance exists on both until GC'd. 
+ for _, shard := range allShards { + if !priorityState[shard.id] { + return false + } + } + return true +} + +// purgeShard removes a decommissioned shard's ID from all tracking maps to prevent memory leaks. +func (s *flowState) purgeShard(shardID string) { + delete(s.activeQueueEmptyOnShards, shardID) + for _, priorityState := range s.drainingQueuesEmptyOnShards { + delete(priorityState, shardID) + } +} diff --git a/pkg/epp/flowcontrol/registry/flowstate_test.go b/pkg/epp/flowcontrol/registry/flowstate_test.go new file mode 100644 index 000000000..0fa4401f3 --- /dev/null +++ b/pkg/epp/flowcontrol/registry/flowstate_test.go @@ -0,0 +1,291 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package registry + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" +) + +// fsTestHarness holds all the common components for a `flowState` test. +type fsTestHarness struct { + t *testing.T + fs *flowState + allShards []*registryShard + activeShards []*registryShard + spec types.FlowSpecification +} + +// newFsTestHarness creates a new test harness, initializing a `flowState` with a default spec and shard layout. +func newFsTestHarness(t *testing.T) *fsTestHarness { + t.Helper() + + spec := types.FlowSpecification{ID: "f1", Priority: pHigh} + activeShards := []*registryShard{{id: "s1"}, {id: "s2"}} + allShards := append(activeShards, ®istryShard{id: "s3-draining"}) // `s3` is conceptually draining + + fs := newFlowState(spec, allShards) + + return &fsTestHarness{ + t: t, + fs: fs, + allShards: allShards, + activeShards: activeShards, + spec: spec, + } +} + +func TestFlowState(t *testing.T) { + t.Parallel() + + t.Run("New", func(t *testing.T) { + t.Parallel() + h := newFsTestHarness(t) + + assert.Equal(t, h.spec, h.fs.spec, "Spec should be correctly initialized") + assert.Equal(t, uint64(1), h.fs.generation, "Initial generation should be 1") + require.Len(t, h.fs.activeQueueEmptyOnShards, 3, "Should track all initial shards") + + // New queues should start empty on all shards. 
+ assert.True(t, h.fs.activeQueueEmptyOnShards["s1"], "Queue on s1 should start empty") + assert.True(t, h.fs.activeQueueEmptyOnShards["s2"], "Queue on s2 should start empty") + assert.True(t, h.fs.activeQueueEmptyOnShards["s3-draining"], "Queue on s3-draining should start empty") + }) + + t.Run("Update", func(t *testing.T) { + t.Parallel() + + specHigh := types.FlowSpecification{ID: "f1", Priority: pHigh} + specLow := types.FlowSpecification{ID: "f1", Priority: pLow} + + testCases := []struct { + name string + setup func(h *fsTestHarness) + updatedSpec types.FlowSpecification + assertions func(h *fsTestHarness) + }{ + { + name: "WhenPriorityIsUnchanged_ShouldPreserveState", + setup: func(h *fsTestHarness) { + h.fs.handleQueueSignal("s1", pHigh, queueStateSignalBecameNonEmpty) + }, + updatedSpec: specHigh, + assertions: func(h *fsTestHarness) { + assert.Equal(t, uint64(2), h.fs.generation, "Generation should increment on any update") + assert.Equal(t, specHigh, h.fs.spec, "Spec should be updated") + assert.Empty(t, h.fs.drainingQueuesEmptyOnShards, "Draining map should be empty when priority does not change") + assert.False(t, h.fs.activeQueueEmptyOnShards["s1"], "State of non-emptiness should be preserved") + }, + }, + { + name: "WhenPriorityChanges_ShouldMoveOldActiveToDraining", + setup: func(h *fsTestHarness) { + h.fs.handleQueueSignal("s1", pHigh, queueStateSignalBecameNonEmpty) + require.False(t, h.fs.activeQueueEmptyOnShards["s1"], "Test setup: s1 should be non-empty") + }, + updatedSpec: specLow, + assertions: func(h *fsTestHarness) { + assert.Equal(t, uint64(2), h.fs.generation, "Generation should increment after update") + assert.Equal(t, specLow, h.fs.spec, "Spec should be updated to new priority") + + require.Contains(t, h.fs.drainingQueuesEmptyOnShards, uint(pHigh), "pHigh should now be in the draining map") + drainingState := h.fs.drainingQueuesEmptyOnShards[pHigh] + assert.False(t, drainingState["s1"], "s1 should still be non-empty for pHigh (draining)") + assert.True(t, drainingState["s2"], "s2 should still be empty for pHigh (draining)") + + assert.True(t, h.fs.activeQueueEmptyOnShards["s1"], "s1 should start empty for new pLow (active)") + assert.True(t, h.fs.activeQueueEmptyOnShards["s2"], "s2 should start empty for new pLow (active)") + }, + }, + { + name: "WhenPriorityRollsBack_ShouldReactivateDrainingQueue", + setup: func(h *fsTestHarness) { + // Step 1: `pHigh` -> `pLow`. This makes `pHigh` draining. + h.fs.update(specLow, h.allShards) + require.Contains(t, h.fs.drainingQueuesEmptyOnShards, uint(pHigh), "Test setup: pHigh should be draining") + require.Equal(t, uint64(2), h.fs.generation, "Test setup: generation should be 2") + }, + updatedSpec: specHigh, // This is the rollback update. 
+ assertions: func(h *fsTestHarness) { + assert.Equal(t, uint64(3), h.fs.generation, "Generation should increment upon reactivation") + assert.NotContains(t, h.fs.drainingQueuesEmptyOnShards, uint(pHigh), + "pHigh should be removed from draining map on reactivation") + assert.Contains(t, h.fs.drainingQueuesEmptyOnShards, uint(pLow), "pLow should now be in the draining map") + assert.Equal(t, specHigh, h.fs.spec, "Spec should be back to pHigh") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + h := newFsTestHarness(t) + if tc.setup != nil { + tc.setup(h) + } + + h.fs.update(tc.updatedSpec, h.allShards) + + tc.assertions(h) + }) + } + }) + + t.Run("HandleQueueSignal", func(t *testing.T) { + t.Parallel() + testCases := []struct { + name string + signal queueStateSignal + shardID string + priority uint + setup func(h *fsTestHarness) + assertState func(h *fsTestHarness) + }{ + { + name: "BecameNonEmpty_ShouldMarkActiveQueueAsNonEmpty", + signal: queueStateSignalBecameNonEmpty, + shardID: "s1", + priority: pHigh, + assertState: func(h *fsTestHarness) { assert.False(t, h.fs.activeQueueEmptyOnShards["s1"]) }, + }, + { + name: "BecameEmpty_ShouldMarkActiveQueueAsEmpty", + signal: queueStateSignalBecameEmpty, + shardID: "s1", + priority: pHigh, + setup: func(h *fsTestHarness) { h.fs.activeQueueEmptyOnShards["s1"] = false }, + assertState: func(h *fsTestHarness) { assert.True(t, h.fs.activeQueueEmptyOnShards["s1"]) }, + }, + { + name: "BecameDrained_ShouldMarkDrainingQueueAsEmpty", + signal: queueStateSignalBecameDrained, + shardID: "s2", + priority: pHigh, // The original priority, which is now draining + setup: func(h *fsTestHarness) { + h.fs.update(types.FlowSpecification{ID: "f1", Priority: pLow}, h.allShards) + h.fs.drainingQueuesEmptyOnShards[pHigh]["s2"] = false + }, + assertState: func(h *fsTestHarness) { assert.True(t, h.fs.drainingQueuesEmptyOnShards[pHigh]["s2"]) }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + h := newFsTestHarness(t) + if tc.setup != nil { + tc.setup(h) + } + h.fs.handleQueueSignal(tc.shardID, tc.priority, tc.signal) + tc.assertState(h) + }) + } + }) + + t.Run("IsIdle", func(t *testing.T) { + t.Parallel() + testCases := []struct { + name string + setup func(h *fsTestHarness) + expectIdle bool + }{ + {"WhenAllActiveShardsAreEmpty_ShouldReturnTrue", nil, true}, + { + "WhenOneActiveShardIsNonEmpty_ShouldReturnFalse", + func(h *fsTestHarness) { h.fs.handleQueueSignal("s1", pHigh, queueStateSignalBecameNonEmpty) }, + false, + }, + { + "WhenOnlyDrainingShardsAreNonEmpty_ShouldReturnTrue", + func(h *fsTestHarness) { h.fs.handleQueueSignal("s3-draining", pHigh, queueStateSignalBecameNonEmpty) }, + true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + h := newFsTestHarness(t) + if tc.setup != nil { + tc.setup(h) + } + assert.Equal(t, tc.expectIdle, h.fs.isIdle(h.activeShards)) + }) + } + }) + + t.Run("IsDrained", func(t *testing.T) { + t.Parallel() + drainingPriority := uint(pHigh) + activePriority := uint(pLow) + testCases := []struct { + name string + priorityToTest uint + setup func(h *fsTestHarness) + expectDrained bool + }{ + {"WhenQueueIsEmptyOnAllShards_ShouldReturnTrue", drainingPriority, nil, true}, + { + "WhenQueueIsNonEmptyOnActiveShard_ShouldReturnFalse", + drainingPriority, + func(h *fsTestHarness) { h.fs.drainingQueuesEmptyOnShards[drainingPriority]["s1"] = false }, + false, + }, + { + 
"WhenQueueIsNonEmptyOnDrainingShard_ShouldReturnFalse", + drainingPriority, + func(h *fsTestHarness) { h.fs.drainingQueuesEmptyOnShards[drainingPriority]["s3-draining"] = false }, + false, + }, + {"ForActivePriority_ShouldReturnFalse", activePriority, nil, false}, + {"ForUnknownPriority_ShouldReturnFalse", 99, nil, false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + h := newFsTestHarness(t) + h.fs.update(types.FlowSpecification{ID: "f1", Priority: pLow}, h.allShards) + if tc.setup != nil { + tc.setup(h) + } + assert.Equal(t, tc.expectDrained, h.fs.isDrained(tc.priorityToTest, h.allShards)) + }) + } + }) + + t.Run("PurgeShard", func(t *testing.T) { + t.Parallel() + h := newFsTestHarness(t) + // Simulate a priority change to populate both active and draining maps. + h.fs.update(types.FlowSpecification{ID: "f1", Priority: pLow}, h.allShards) + shardToPurge := "s2" + require.Contains(t, h.fs.activeQueueEmptyOnShards, shardToPurge, "Test setup: s2 must be in active map") + require.Contains(t, h.fs.drainingQueuesEmptyOnShards[pHigh], shardToPurge, "Test setup: s2 must be in draining map") + + h.fs.purgeShard(shardToPurge) + + assert.NotContains(t, h.fs.activeQueueEmptyOnShards, shardToPurge, "s2 should be purged from active map") + assert.Contains(t, h.fs.activeQueueEmptyOnShards, "s1", "s1 should remain in active map") + assert.NotContains(t, h.fs.drainingQueuesEmptyOnShards[pHigh], shardToPurge, "s2 should be purged from draining map") + assert.Contains(t, h.fs.drainingQueuesEmptyOnShards[pHigh], "s1", "s1 should remain in draining map") + }) +} diff --git a/pkg/epp/flowcontrol/registry/gc.go b/pkg/epp/flowcontrol/registry/gc.go new file mode 100644 index 000000000..c665044af --- /dev/null +++ b/pkg/epp/flowcontrol/registry/gc.go @@ -0,0 +1,114 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package registry + +import ( + "sync" + "time" + + "k8s.io/utils/clock" +) + +// gcTimer represents a single cancellable garbage collection timer and the generation it is associated with. +type gcTimer struct { + // timer is the active timer instance. We use the abstraction to allow for deterministic testing. + timer clock.Timer + generation uint64 +} + +// gcTracker is a concurrent-safe, decoupled manager for garbage collection timers. +// +// # Role: Decoupled Timer Management +// +// Its sole responsibility is to start, stop, and report timer expirations via a channel. It is explicitly designed to +// have no knowledge of the `FlowRegistry`'s internal state or the meaning of the flows it tracks. This decoupling keeps +// the timer logic simple, thread-safe, and reusable. +// +// # Concurrency and Race Handling +// +// The gcTracker uses a `sync.Mutex` to protect its internal map of timers. +// +// A critical aspect of its design is the handling of the inherent race condition in `time.Timer.Stop()`. A timer might +// fire just as it is being stopped or replaced (`Stop()` returns false if the timer already fired). 
The `gcTracker` +// addresses this by associating a generation ID with each timer. When a timer fires, it sends this generation ID +// along with the event. The consumer (`FlowRegistry`) is responsible for checking if the generation in the event +// matches the current generation of the flow, thereby ignoring stale timer events. A stale event can occur if a flow +// becomes idle (timer starts, gen=N), then becomes active again, then idle again (timer is replaced, gen=N+1). If the +// original timer for gen=N fires after this sequence, the registry will correctly ignore it by comparing generations. +type gcTracker struct { + mu sync.Mutex + // timers maps a flow ID to its active garbage collection timer. + timers map[string]*gcTimer + // eventCh is the channel over which timer expiration events are sent. + eventCh chan<- event + // clock provides time-related functions. It allows injecting a mock clock for testing. + // We use `WithTickerAndDelayedExecution` as it guarantees the `AfterFunc` method. + clock clock.WithTickerAndDelayedExecution +} + +// newGCTracker creates a new garbage collection timer manager. +// Requires a clock implementation (provided by `FlowRegistry`). +func newGCTracker(eventCh chan<- event, clk clock.WithTickerAndDelayedExecution) *gcTracker { + return &gcTracker{ + timers: make(map[string]*gcTimer), + eventCh: eventCh, + clock: clk, + } +} + +// start begins a new GC timer for a given flow and generation. If a timer already exists for the flow, it is implicitly +// stopped and replaced. This is the desired behavior, as a new call to start a timer for a flow (e.g., because it just +// became idle) should always supersede any previous timer. +func (gc *gcTracker) start(flowID string, generation uint64, timeout time.Duration) { + gc.mu.Lock() + defer gc.mu.Unlock() + + // If a timer already exists for this flow, stop it before starting a new one. This handles cases where a flow might + // flap between active and idle states. + if existing, ok := gc.timers[flowID]; ok { + existing.timer.Stop() + } + + // We use AfterFunc which works efficiently with both `RealClock` and `FakeClock` (for tests). + timer := gc.clock.AfterFunc(timeout, func() { + // When the timer fires, send the event to the registry for processing. + // This happens asynchronously (triggered by time passage or `FakeClock.Step`). + gc.eventCh <- &gcTimerFiredEvent{ + flowID: flowID, + generation: generation, + } + }) + + gc.timers[flowID] = &gcTimer{ + timer: timer, + generation: generation, + } +} + +// stop halts and deletes the timer for a given flow. It is safe to call even if no timer exists for the flow. +func (gc *gcTracker) stop(flowID string) { + gc.mu.Lock() + defer gc.mu.Unlock() + + if existing, ok := gc.timers[flowID]; ok { + // Attempt to stop the timer. If `timer.Stop()` returns false, it means the timer has already fired and its callback + // has been sent to the event channel. The `FlowRegistry`'s event handler is responsible for correctly handling this + // race condition by checking the flow's generation ID. + existing.timer.Stop() + delete(gc.timers, flowID) + } +} diff --git a/pkg/epp/flowcontrol/registry/gc_test.go b/pkg/epp/flowcontrol/registry/gc_test.go new file mode 100644 index 000000000..9bcbedbff --- /dev/null +++ b/pkg/epp/flowcontrol/registry/gc_test.go @@ -0,0 +1,175 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
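The generation-check protocol described above is easiest to see from the consumer side. A minimal sketch of how the registry's event loop might discard stale timer events, assuming a per-flow state with a `generation` counter; the `FlowRegistry` receiver, the `flowStates` map, and the `garbageCollectFlow` helper are illustrative names, not part of this change:

func (fr *FlowRegistry) onGCTimerFired(evt *gcTimerFiredEvent) {
	fs, ok := fr.flowStates[evt.flowID]
	if !ok || fs.generation != evt.generation {
		// Stale event: the flow was unregistered, or it flapped idle -> active -> idle
		// and a newer timer (with a higher generation) now owns GC for this flow.
		return
	}
	// The generation matches, so the idle timeout that armed this timer is still in force.
	fr.garbageCollectFlow(evt.flowID)
}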
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package registry + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + clocktesting "k8s.io/utils/clock/testing" +) + +// gcTestHarness encapsulates all the components needed for a gcTracker test. +type gcTestHarness struct { + t *testing.T + fakeClock *clocktesting.FakeClock + eventCh chan event + gc *gcTracker +} + +// newGCTrackerTestHarness creates a new test harness with a deterministic fake clock. +// The channel can be buffered or unbuffered depending on the test requirements. +func newGCTrackerTestHarness(t *testing.T, eventChanBuffer int) *gcTestHarness { + t.Helper() + eventCh := make(chan event, eventChanBuffer) + fakeClock := clocktesting.NewFakeClock(time.Now()) + gc := newGCTracker(eventCh, fakeClock) + + return &gcTestHarness{ + t: t, + fakeClock: fakeClock, + eventCh: eventCh, + gc: gc, + } +} + +// assertEventReceived checks that a GC event was received on the channel with the expected properties. +func (h *gcTestHarness) assertEventReceived(expectedFlowID string, expectedGen uint64) { + h.t.Helper() + select { + case evt := <-h.eventCh: + timerEvt, ok := evt.(*gcTimerFiredEvent) + require.True(h.t, ok, "Event should be of type gcTimerFiredEvent") + assert.Equal(h.t, expectedFlowID, timerEvt.flowID, "Event should have the correct FlowID") + assert.Equal(h.t, expectedGen, timerEvt.generation, "Event should have the correct Generation") + case <-time.After(1 * time.Second): // Use a real timeout for the test itself to prevent hangs. + h.t.Fatal("Timeout: Did not receive expected GC event") + } +} + +// assertNoEventReceived checks that no GC event was sent on the channel. +func (h *gcTestHarness) assertNoEventReceived(format string, args ...any) { + h.t.Helper() + select { + case evt := <-h.eventCh: + // Combine the static error with the custom message for a comprehensive failure report. + customMsg := fmt.Sprintf(format, args...) + h.t.Fatalf("Received unexpected GC event: %v. Assertion message: %s", evt, customMsg) + default: + // Success, no event received. + } +} + +func TestGCTracker(t *testing.T) { + t.Parallel() + + const timeout = 10 * time.Second + const flowID = "test-flow" + + t.Run("BasicLifecycle", func(t *testing.T) { + t.Parallel() + + t.Run("WhenTimerFires_ShouldSendEvent", func(t *testing.T) { + t.Parallel() + h := newGCTrackerTestHarness(t, 1) + + h.gc.start(flowID, 1, timeout) + h.assertNoEventReceived("No event should be sent before the timeout expires") + + // Advance the clock to fire the timer. + h.fakeClock.Step(timeout) + h.assertEventReceived(flowID, 1) + }) + + t.Run("WhenTimerIsStopped_ShouldNotFire", func(t *testing.T) { + t.Parallel() + h := newGCTrackerTestHarness(t, 1) + + h.gc.start(flowID, 2, timeout) + h.gc.stop(flowID) + + // Advance the clock past the original timeout. + h.fakeClock.Step(timeout) + h.assertNoEventReceived("No event should be received for a stopped timer") + }) + + t.Run("WhenTimerIsReplaced_ShouldSupersedeOldTimer", func(t *testing.T) { + t.Parallel() + h := newGCTrackerTestHarness(t, 1) + + // Start the first timer (gen 3). 
+ h.gc.start(flowID, 3, timeout) + // Start a new, longer timer immediately (gen 4). This should cancel the first one. + h.gc.start(flowID, 4, timeout*3) + + // Advance the clock just enough to fire the first timer. + h.fakeClock.Step(timeout) + h.assertNoEventReceived("The superseded timer (gen 3) should not have fired") + + // Now advance the clock to fire the second, active timer. + h.fakeClock.Step(timeout * 2) + h.assertEventReceived(flowID, 4) + }) + }) + + t.Run("EdgeCasesAndSafety", func(t *testing.T) { + t.Parallel() + + t.Run("WhenStoppingNonExistentTimer_ShouldNotPanic", func(t *testing.T) { + t.Parallel() + h := newGCTrackerTestHarness(t, 1) + assert.NotPanics(t, func() { + h.gc.stop("non-existent-flow") + }, "stop() should not panic for a non-existent flow") + }) + + // This test verifies the behavior of the `gcTracker` when `stop()` is called for a timer that has already fired and + // sent its event. The tracker's contract is that the consumer (`FlowRegistry`) will receive the event, and the + // subsequent `stop()` call must gracefully clean up the internal state. + t.Run("WhenStoppingTimer_ShouldCleanupMapAndPreserveEvent", func(t *testing.T) { + t.Parallel() + // Use a buffered channel so that `Step()` can complete without blocking. + h := newGCTrackerTestHarness(t, 1) + + // Start the timer. + h.gc.start(flowID, 5, timeout) + + // Advance the clock. This fires the timer and sends the event to the buffered channel. + // The `Step()` call returns, and the event is now "in-flight". + h.fakeClock.Step(timeout) + + // At this point, the timer has fired, but it might still exist in the `gcTracker`'s map until its `AfterFunc` + // completes fully. + // Now, call `stop()`. This simulates the case where `stop()` is called for a flow whose timer has already fired + // and sent its event. + h.gc.stop(flowID) + + // Assert that the timer has been removed from the internal map. + // This is the primary responsibility of `stop()`. + h.gc.mu.Lock() + assert.NotContains(t, h.gc.timers, flowID, "stop() should clean up the internal map even for a fired timer") + h.gc.mu.Unlock() + + // Assert that the event from the fired timer is still received. + // This proves `stop()` doesn't interfere with an already-sent event. + h.assertEventReceived(flowID, 5) + }) + }) +} diff --git a/pkg/epp/flowcontrol/registry/lifecycle.go b/pkg/epp/flowcontrol/registry/lifecycle.go new file mode 100644 index 000000000..710b41348 --- /dev/null +++ b/pkg/epp/flowcontrol/registry/lifecycle.go @@ -0,0 +1,192 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package registry + +import ( + "fmt" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" +) + +// ============================================================================= +// Component Lifecycle State Machine +// +// The registry manages stateful components (`managedQueue`, `registryShard`). +// These components follow a well-defined lifecycle. 
The system is an +// event-driven state machine where changes in state trigger signals, which are +// processed by the central `FlowRegistry` control plane. +// ============================================================================= + +// componentStatus represents the lifecycle state of a managed component (Queue or Shard). +// It is intended to be stored and manipulated using atomic operations (e.g., `atomic.Int32`) to ensure robust, atomic +// state transitions and eliminate inconsistent intermediate states. +type componentStatus int32 + +const ( + // componentStatusActive indicates the component is fully operational and accepting new work. + componentStatusActive componentStatus = iota + + // componentStatusDraining indicates the component is shutting down. It is not accepting new work, but is still + // processing existing work. + componentStatusDraining + + // componentStatusDrained indicates the component has finished draining and is now empty. + // The transition into this state (from `componentStatusDraining`) occurs exactly once via `CompareAndSwap` and + // triggers the corresponding `BecameDrained` signal. This acts as an atomic latch for GC. + componentStatusDrained +) + +func (s componentStatus) String() string { + switch s { + case componentStatusActive: + return "Active" + case componentStatusDraining: + return "Draining" + case componentStatusDrained: + return "Drained" + default: + return fmt.Sprintf("Unknown(%d)", s) + } +} + +// ============================================================================= +// Control Plane Signals and Callbacks +// +// These definitions establish the communication protocol from the data plane components (Queues, Shards) up to the +// control plane (`FlowRegistry`). +// ============================================================================= + +// --- Queue Signals --- + +// queueStateSignal is an enum describing the edge-triggered state change events emitted by a `managedQueue`. +type queueStateSignal int + +const ( + // queueStateSignalBecameEmpty is sent when an Active queue transitions from non-empty to empty. + // Trigger: Len > 0 -> Len == 0. Used for inactivity GC tracking. + queueStateSignalBecameEmpty queueStateSignal = iota + + // queueStateSignalBecameNonEmpty is sent when an Active queue transitions from empty to non-empty. + // Trigger: Len == 0 -> Len > 0. Used for inactivity GC tracking. + queueStateSignalBecameNonEmpty + + // queueStateSignalBecameDrained is sent when a Draining queue transitions to empty. + // Trigger: Transition from `componentStatusDraining` -> `componentStatusDrained`. Used for final GC of the queue + // instance. + queueStateSignalBecameDrained +) + +func (s queueStateSignal) String() string { + switch s { + case queueStateSignalBecameEmpty: + return "QueueBecameEmpty" + case queueStateSignalBecameNonEmpty: + return "QueueBecameNonEmpty" + case queueStateSignalBecameDrained: + return "QueueBecameDrained" + default: + return fmt.Sprintf("Unknown(%d)", s) + } +} + +// signalQueueStateFunc defines the callback function that a `managedQueue` uses to signal lifecycle events to its +// parent shard. +// Implementations MUST NOT block on internal locks or I/O. However, they are expected to block if the `FlowRegistry`'s +// event channel is full; this is required behavior to apply necessary backpressure and ensure reliable event delivery. 
+type signalQueueStateFunc func(spec types.FlowSpecification, signal queueStateSignal) + +// --- Shard Signals --- + +// shardStateSignal is an enum describing the edge-triggered state change events emitted by a `registryShard`. +type shardStateSignal int + +const ( + // shardStateSignalBecameDrained is sent when a Draining shard transitions to empty. + // Trigger: Transition from `componentStatusDraining` -> componentStatusDrained`. Used for final GC of the shard. + shardStateSignalBecameDrained shardStateSignal = iota +) + +func (s shardStateSignal) String() string { + switch s { + case shardStateSignalBecameDrained: + return "ShardBecameDrained" + default: + return fmt.Sprintf("Unknown(%d)", s) + } +} + +// signalShardStateFunc defines the callback function that a `registryShard` uses to signal its own state changes to the +// parent registry. +// Implementations MUST NOT block on internal locks or I/O. However, they are expected to block if the `FlowRegistry`'s +// event channel is full; this is required behavior to apply necessary backpressure and ensure reliable event delivery. +// Note: The `registryShard` pointer is passed here to facilitate efficient identification and removal during GC. +type signalShardStateFunc func(shard *registryShard, signal shardStateSignal) + +// --- Statistics Propagation --- + +// propagateStatsDeltaFunc defines the callback function that a component uses to propagate its statistics changes +// (deltas) up to its parent (Queue -> Shard -> Registry). +// Implementations MUST be non-blocking (relying on atomics) and must not acquire any locks held by the caller. +type propagateStatsDeltaFunc func(priority uint, lenDelta, byteSizeDelta int64) + +// ============================================================================= +// Control Plane Events (Transport) +// +// These structures define the data transported over the `FlowRegistry`'s event +// channel, carrying the signals defined above or timer expirations to the +// centralized event loop. +// ============================================================================= + +// event is a marker interface for internal state machine events processed by the `FlowRegistry`'s event loop. +type event interface { + isEvent() +} + +// gcTimerFiredEvent is sent when a flow's garbage collection timer expires. +type gcTimerFiredEvent struct { + flowID string + generation uint64 +} + +func (gcTimerFiredEvent) isEvent() {} + +// queueStateChangedEvent is sent when a `managedQueue`'s state changes, carrying a `queueStateSignal`. +type queueStateChangedEvent struct { + shardID string + spec types.FlowSpecification + signal queueStateSignal +} + +func (queueStateChangedEvent) isEvent() {} + +// shardStateChangedEvent is sent when a `registryShard`'s state changes, carrying a `shardStateSignal`. +type shardStateChangedEvent struct { + shard *registryShard + signal shardStateSignal +} + +func (shardStateChangedEvent) isEvent() {} + +// syncEvent is a special event used exclusively for testing to synchronize the test execution with the `FlowRegistry`'s +// event loop. It allows tests to wait until all preceding events have been processed, eliminating the need for polling. +type syncEvent struct { + // doneCh is used by the event loop to signal back to the sender that the sync event (and thus all preceding events) + // has been processed. 
+ doneCh chan struct{} +} + +func (syncEvent) isEvent() {} diff --git a/pkg/epp/flowcontrol/registry/managedqueue.go b/pkg/epp/flowcontrol/registry/managedqueue.go index bea7e7e4f..cc63823aa 100644 --- a/pkg/epp/flowcontrol/registry/managedqueue.go +++ b/pkg/epp/flowcontrol/registry/managedqueue.go @@ -28,53 +28,71 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -// parentStatsReconciler defines the callback function that a `managedQueue` uses to propagate its statistics changes up -// to its parent `registryShard`. -type parentStatsReconciler func(lenDelta, byteSizeDelta int64) - -// managedQueue implements `contracts.ManagedQueue`. It is a stateful decorator that wraps a `framework.SafeQueue`, -// augmenting it with two critical, registry-level responsibilities: -// 1. Atomic Statistics: It maintains its own `len` and `byteSize` counters, which are updated atomically. This allows -// the parent `registryShard` to aggregate statistics across many queues without locks. -// 2. Lifecycle Enforcement: It tracks the queue's lifecycle state (active vs. draining) via an `isActive` flag. This -// is crucial for graceful flow updates, as it allows the registry to stop new requests from being enqueued while -// allowing existing items to be drained. +// managedQueueCallbacks groups the callback functions that a `managedQueue` uses to communicate with its parent shard. +type managedQueueCallbacks struct { + // propagateStatsDelta is called to propagate statistics changes (e.g., queue length, byte size) up to the parent. + propagateStatsDelta propagateStatsDeltaFunc + // signalQueueState is called to signal important lifecycle events, such as becoming empty or fully drained, which are + // used by the garbage collector. + signalQueueState signalQueueStateFunc +} + +// managedQueue implements `contracts.ManagedQueue`. It is a stateful decorator that wraps a generic +// `framework.SafeQueue` to enrich it with the state and behaviors required by the registry's control plane. +// +// # Role: The Stateful Decorator +// +// Its responsibilities are centered around augmenting a generic queue implementation with critical registry features. +// It is designed for high-performance, concurrent access on the hot path (Enqueue/Dequeue). +// +// 1. Atomic Statistics (Lock-Free): It maintains its own `len` and `byteSize` counters, which are updated atomically +// using sophisticated lock-free patterns (see `propagateStatsDelta`). This provides O(1), lock-free access for the +// parent `registryShard` to aggregate statistics across many queues. // -// # Statistical Integrity +// 2. Lifecycle Enforcement (Active/Draining/Drained): It tracks the queue's lifecycle state via an atomic `status` +// enum (`componentStatus`). This allows for robust, atomic state transitions and is crucial for graceful flow +// updates (e.g., stopping new requests while allowing existing items to drain). // -// For performance, `managedQueue` maintains its own `len` and `byteSize` fields using atomic operations. This provides -// O(1) access for the parent `registryShard`'s aggregated statistics without needing to lock the underlying queue. +// 3. Exactly-Once Edge Signaling: It uses atomic operations to detect state transitions. Crucially, it ensures that +// signals (like `queueStateSignalBecameEmpty` or `queueStateSignalBecameDrained`) are sent exactly once when the +// transition occurs. 
The transition from `componentStatusDraining` to `componentStatusDrained` acts as an atomic +// latch for the Drained signal. +// +// # Statistical Integrity and Assumptions // -// This design is predicated on two critical assumptions: // 1. Exclusive Access: All mutating operations on the underlying `framework.SafeQueue` MUST be performed exclusively // through this `managedQueue` wrapper. Direct access to the underlying queue will cause statistical drift. -// 2. In-Process Queue: The `framework.SafeQueue` implementation is an in-process data structure (e.g., a list or -// heap). Its state MUST NOT change through external mechanisms. For example, a queue implementation backed by a -// distributed cache like Redis with its own TTL-based eviction policy would violate this assumption and lead to -// state inconsistency, as items could be removed without notifying the `managedQueue`. // -// This approach avoids the need for the `framework.SafeQueue` interface to return state deltas on each operation, -// keeping its contract simpler. +// 2. In-Process, Stable State: The `framework.SafeQueue` implementation must be an in-process data structure (e.g., a +// list or heap). Its state MUST NOT change through external mechanisms. For example, a queue implementation backed +// by a distributed cache (like Redis) with its own TTL-based eviction policy would violate this assumption and lead +// to state inconsistency, as items could be removed without notifying the `managedQueue`. type managedQueue struct { - // Note: There is no mutex here. Concurrency control is delegated to the underlying `framework.SafeQueue` and the - // atomic operations on the stats fields. - queue framework.SafeQueue - dispatchPolicy framework.IntraFlowDispatchPolicy - flowSpec types.FlowSpecification - byteSize atomic.Uint64 - len atomic.Uint64 - isActive atomic.Bool - reconcileShardStats parentStatsReconciler - logger logr.Logger + // Note: There is no mutex here. Concurrency control is delegated to the underlying `framework.SafeQueue` for queue + // operations, and atomic operations are used for stats and lifecycle status. + queue framework.SafeQueue + dispatchPolicy framework.IntraFlowDispatchPolicy + flowSpec types.FlowSpecification + byteSize atomic.Uint64 + len atomic.Uint64 + + // status tracks the lifecycle state of the queue (Active, Draining, Drained). + // It is stored as an `int32` for atomic operations. + status atomic.Int32 // `componentStatus` + + parentCallbacks managedQueueCallbacks + logger logr.Logger } +var _ contracts.ManagedQueue = &managedQueue{} + // newManagedQueue creates a new instance of a `managedQueue`. func newManagedQueue( queue framework.SafeQueue, dispatchPolicy framework.IntraFlowDispatchPolicy, flowSpec types.FlowSpecification, logger logr.Logger, - reconcileShardStats parentStatsReconciler, + parentCallbacks managedQueueCallbacks, ) *managedQueue { mqLogger := logger.WithName("managed-queue").WithValues( "flowID", flowSpec.ID, @@ -82,42 +100,36 @@ func newManagedQueue( "queueType", queue.Name(), ) mq := &managedQueue{ - queue: queue, - dispatchPolicy: dispatchPolicy, - flowSpec: flowSpec, - reconcileShardStats: reconcileShardStats, - logger: mqLogger, + queue: queue, + dispatchPolicy: dispatchPolicy, + flowSpec: flowSpec, + parentCallbacks: parentCallbacks, + logger: mqLogger, } - mq.isActive.Store(true) + // Initialize the queue in the Active state. 
+ mq.status.Store(int32(componentStatusActive)) return mq } -// markAsDraining is an internal method called by the parent shard to transition this queue to a draining state. -// Once a queue is marked as draining, it will no longer accept new items via `Add`. -func (mq *managedQueue) markAsDraining() { - // Use CompareAndSwap to ensure we only log the transition once. - if mq.isActive.CompareAndSwap(true, false) { - mq.logger.V(logging.DEFAULT).Info("Queue marked as draining") - } -} - -// FlowQueueAccessor returns a new `flowQueueAccessor` instance, which provides a read-only, policy-facing view of the -// queue. +// FlowQueueAccessor returns a read-only, flow-aware view of this queue. +// This accessor is primarily used by policy plugins to inspect the queue's state in a structured way. func (mq *managedQueue) FlowQueueAccessor() framework.FlowQueueAccessor { return &flowQueueAccessor{mq: mq} } -// Add first checks if the queue is active. If it is, it wraps the underlying `framework.SafeQueue.Add` call and -// atomically updates the queue's and the parent shard's statistics. +// Add wraps the underlying `framework.SafeQueue.Add` call. It first checks if the queue is active. If so, it proceeds +// with the addition and atomically updates the queue's and the parent shard's statistics. func (mq *managedQueue) Add(item types.QueueItemAccessor) error { - if !mq.isActive.Load() { - return fmt.Errorf("flow instance %q is not active and cannot accept new requests: %w", - mq.flowSpec.ID, contracts.ErrFlowInstanceNotFound) + // Only StatusActive queues can accept new requests. + status := componentStatus(mq.status.Load()) + if status != componentStatusActive { + return fmt.Errorf("flow instance %q is not active (status: %s) and cannot accept new requests: %w", + mq.flowSpec.ID, status, contracts.ErrFlowInstanceNotFound) } if err := mq.queue.Add(item); err != nil { return err } - mq.reconcileStats(1, int64(item.OriginalRequest().ByteSize())) + mq.propagateStatsDelta(1, int64(item.OriginalRequest().ByteSize())) mq.logger.V(logging.TRACE).Info("Request added to queue", "requestID", item.OriginalRequest().ID()) return nil } @@ -129,7 +141,7 @@ func (mq *managedQueue) Remove(handle types.QueueItemHandle) (types.QueueItemAcc if err != nil { return nil, err } - mq.reconcileStats(-1, -int64(removedItem.OriginalRequest().ByteSize())) + mq.propagateStatsDelta(-1, -int64(removedItem.OriginalRequest().ByteSize())) mq.logger.V(logging.TRACE).Info("Request removed from queue", "requestID", removedItem.OriginalRequest().ID()) return removedItem, nil } @@ -138,59 +150,37 @@ func (mq *managedQueue) Remove(handle types.QueueItemHandle) (types.QueueItemAcc // items. 
func (mq *managedQueue) Cleanup(predicate framework.PredicateFunc) (cleanedItems []types.QueueItemAccessor, err error) { cleanedItems, err = mq.queue.Cleanup(predicate) - if err != nil || len(cleanedItems) == 0 { - return cleanedItems, err + if err != nil { + return nil, err } - - var lenDelta int64 - var byteSizeDelta int64 - for _, item := range cleanedItems { - lenDelta-- - byteSizeDelta -= int64(item.OriginalRequest().ByteSize()) + if len(cleanedItems) == 0 { + return cleanedItems, nil } - mq.reconcileStats(lenDelta, byteSizeDelta) - mq.logger.V(logging.DEBUG).Info("Cleaned up queue", "removedItemCount", len(cleanedItems), - "lenDelta", lenDelta, "byteSizeDelta", byteSizeDelta) + mq.propagateStatsDeltaForRemovedItems(cleanedItems) + mq.logger.V(logging.DEBUG).Info("Cleaned up queue", "removedItemCount", len(cleanedItems)) return cleanedItems, nil } // Drain wraps the underlying `framework.SafeQueue.Drain` call and atomically updates statistics for all removed items. func (mq *managedQueue) Drain() ([]types.QueueItemAccessor, error) { drainedItems, err := mq.queue.Drain() - if err != nil || len(drainedItems) == 0 { - return drainedItems, err + if err != nil { + return nil, err } - - var lenDelta int64 - var byteSizeDelta int64 - for _, item := range drainedItems { - lenDelta-- - byteSizeDelta -= int64(item.OriginalRequest().ByteSize()) + if len(drainedItems) == 0 { + return drainedItems, nil } - mq.reconcileStats(lenDelta, byteSizeDelta) - mq.logger.V(logging.DEBUG).Info("Drained queue", "itemCount", len(drainedItems), - "lenDelta", lenDelta, "byteSizeDelta", byteSizeDelta) + mq.propagateStatsDeltaForRemovedItems(drainedItems) + mq.logger.V(logging.DEBUG).Info("Drained queue", "itemCount", len(drainedItems)) return drainedItems, nil } -// reconcileStats atomically updates the queue's own statistics and calls the parent shard's reconciler to ensure -// aggregated stats remain consistent. -func (mq *managedQueue) reconcileStats(lenDelta, byteSizeDelta int64) { - // The use of Add with a negative number on a Uint64 is the standard Go atomic way to perform subtraction, leveraging - // two's complement arithmetic. - mq.len.Add(uint64(lenDelta)) - mq.byteSize.Add(uint64(byteSizeDelta)) - if mq.reconcileShardStats != nil { - mq.reconcileShardStats(lenDelta, byteSizeDelta) - } -} - -// --- Pass-through and accessor methods --- +// --- Pass-through and Accessor Methods --- -// Name returns the name of the underlying queue implementation. +// Name returns the name of the queue. func (mq *managedQueue) Name() string { return mq.queue.Name() } -// Capabilities returns the capabilities of the underlying queue implementation. +// Capabilities returns the capabilities of the queue. func (mq *managedQueue) Capabilities() []framework.QueueCapability { return mq.queue.Capabilities() } // Len returns the number of items in the queue. @@ -205,13 +195,101 @@ func (mq *managedQueue) PeekHead() (types.QueueItemAccessor, error) { return mq. // PeekTail returns the item at the back of the queue without removing it. func (mq *managedQueue) PeekTail() (types.QueueItemAccessor, error) { return mq.queue.PeekTail() } -// Comparator returns the `framework.ItemComparator` that defines this queue's item ordering logic, as dictated by its -// configured dispatch policy. +// Comparator returns the `framework.ItemComparator` that defines this queue's item ordering logic. 
func (mq *managedQueue) Comparator() framework.ItemComparator { return mq.dispatchPolicy.Comparator() } -var _ contracts.ManagedQueue = &managedQueue{} +// --- Internal Methods (Called by `registryShard`) --- + +// reactivate is an internal method called by the parent shard to transition this queue from a non-active state +// (Draining or Drained) back to Active. +func (mq *managedQueue) reactivate() { + // Atomically transition the state back to Active. + // We use `Swap` to get the old status for observability. Since this is called under the `FlowRegistry`'s control + // plane lock (via the shard lock), we don't strictly need CompareAndSwap for correctness. + oldStatus := componentStatus(mq.status.Swap(int32(componentStatusActive))) + if oldStatus != componentStatusActive { + // We rely on the `FlowRegistry` to re-evaluate the GC state immediately after the synchronization that caused this + // reactivation. + mq.logger.V(logging.DEFAULT).Info("Queue reactivated", "previousStatus", oldStatus) + } +} -// --- flowQueueAccessor --- +// markAsDraining is an internal method called by the parent shard to transition this queue to a draining state. +// Once draining, it will no longer accept new items via `Add`. +func (mq *managedQueue) markAsDraining() { + // Attempt to transition from Active to Draining atomically. + if mq.status.CompareAndSwap(int32(componentStatusActive), int32(componentStatusDraining)) { + mq.logger.V(logging.DEFAULT).Info("Queue marked as draining") + } + + // CRITICAL: Check if the queue is *already* empty at the moment it's marked as draining (or if it was already + // draining and empty). If so, we must immediately attempt the transition to Drained to ensure timely GC. + // This handles the race where the queue becomes empty just before or during being marked draining. + if mq.Len() == 0 { + // Attempt the transition from Draining to Drained atomically. + if mq.status.CompareAndSwap(int32(componentStatusDraining), int32(componentStatusDrained)) { + mq.parentCallbacks.signalQueueState(mq.flowSpec, queueStateSignalBecameDrained) + } + } +} + +// propagateStatsDelta atomically updates the queue's own statistics and calls the parent shard's propagator. +// It also implements the core state machine logic for signaling lifecycle events to the control plane, which is +// critical for driving garbage collection. +func (mq *managedQueue) propagateStatsDelta(lenDelta, byteSizeDelta int64) { + // The use of Add with a negative number on a `uint64` is the standard Go atomic way to perform subtraction, + // leveraging two's complement arithmetic. + newLen := mq.len.Add(uint64(lenDelta)) + + // CRITICAL: oldLen is derived *after* the atomic operation. This is a deliberate and non-obvious pattern to prevent a + // race condition. If we were to read the value *before* the `Add` operation (e.g., `oldLen := mq.len.Load()`), two + // concurrent goroutines could both read the same `oldLen` value (e.g., 1) before either of them performs the `Add`. + // If both were decrementing the length, they would both calculate `newLen` as 0 and `oldLen` as 1, causing them both + // to incorrectly signal a transition from non-empty to empty. By deriving `oldLen` from `newLen`, we ensure that only + // the goroutine that actually causes the transition to 0 will see the correct `oldLen` of 1. 
+ oldLen := newLen - uint64(lenDelta) + mq.byteSize.Add(uint64(byteSizeDelta)) + + mq.parentCallbacks.propagateStatsDelta(mq.flowSpec.Priority, lenDelta, byteSizeDelta) + + // --- State Machine Signaling Logic --- + + // Case 1: Check for Draining -> Drained transition (Exactly-Once). + // This must happen if the length just hit zero. + if newLen == 0 { + // Attempt to transition from Draining to Drained atomically. + // This acts as the exactly-once latch. If it succeeds, we are the single goroutine responsible for signaling. + if mq.status.CompareAndSwap(int32(componentStatusDraining), int32(componentStatusDrained)) { + mq.parentCallbacks.signalQueueState(mq.flowSpec, queueStateSignalBecameDrained) + return // Drained is a terminal state for signaling until reactivation. + } + } + + // Case 2: Standard Active Queue Empty/Non-Empty transitions. + // We only signal these if the queue is currently Active. + // If it's Draining or Drained, these signals are irrelevant (we are waiting for the Drained signal or reactivation). + // We must check the status again here, as it might have changed concurrently (e.g., if it was reactivated). + if componentStatus(mq.status.Load()) == componentStatusActive { + if oldLen > 0 && newLen == 0 { + mq.parentCallbacks.signalQueueState(mq.flowSpec, queueStateSignalBecameEmpty) + } else if oldLen == 0 && newLen > 0 { + mq.parentCallbacks.signalQueueState(mq.flowSpec, queueStateSignalBecameNonEmpty) + } + } +} + +// propagateStatsDeltaForRemovedItems calculates the total stat changes for a slice of removed items and applies them. +func (mq *managedQueue) propagateStatsDeltaForRemovedItems(items []types.QueueItemAccessor) { + var lenDelta int64 + var byteSizeDelta int64 + for _, item := range items { + lenDelta-- + byteSizeDelta -= int64(item.OriginalRequest().ByteSize()) + } + mq.propagateStatsDelta(lenDelta, byteSizeDelta) +} + +// --- `flowQueueAccessor` --- // flowQueueAccessor implements `framework.FlowQueueAccessor`. It provides a read-only, policy-facing view of a // `managedQueue`. @@ -219,6 +297,8 @@ type flowQueueAccessor struct { mq *managedQueue } +var _ framework.FlowQueueAccessor = &flowQueueAccessor{} + // Name returns the name of the queue. func (a *flowQueueAccessor) Name() string { return a.mq.Name() } @@ -242,5 +322,3 @@ func (a *flowQueueAccessor) Comparator() framework.ItemComparator { return a.mq. // FlowSpec returns the `types.FlowSpecification` of the flow this queue accessor is associated with. func (a *flowQueueAccessor) FlowSpec() types.FlowSpecification { return a.mq.flowSpec } - -var _ framework.FlowQueueAccessor = &flowQueueAccessor{} diff --git a/pkg/epp/flowcontrol/registry/managedqueue_test.go b/pkg/epp/flowcontrol/registry/managedqueue_test.go index 6fc172feb..2bb713629 100644 --- a/pkg/epp/flowcontrol/registry/managedqueue_test.go +++ b/pkg/epp/flowcontrol/registry/managedqueue_test.go @@ -17,10 +17,12 @@ limitations under the License. package registry import ( + "context" "errors" "sync" "sync/atomic" "testing" + "time" "github.com/go-logr/logr" "github.com/stretchr/testify/assert" @@ -35,454 +37,949 @@ import ( typesmocks "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types/mocks" ) -// testStatsReconciler is a mock implementation of the `parentStatsReconciler` function. -// It captures the deltas it's called with, allowing tests to assert on them. 
-type testStatsReconciler struct { - mu sync.Mutex - lenDelta int64 - byteSizeDelta int64 - invocationCount int +// mockStatsPropagator is a mock implementation of `propagateStatsDeltaFunc` for capturing stat changes. +// It uses atomics to be safe for concurrent use in stress tests. +type mockStatsPropagator struct { + lenDelta atomic.Int64 + byteSizeDelta atomic.Int64 + invocations atomic.Int64 } -func (r *testStatsReconciler) reconcile(lenDelta, byteSizeDelta int64) { +func (p *mockStatsPropagator) propagate(_ uint, lenDelta, byteSizeDelta int64) { + p.lenDelta.Add(lenDelta) + p.byteSizeDelta.Add(byteSizeDelta) + p.invocations.Add(1) +} + +func (p *mockStatsPropagator) getStats() (lenDelta, byteSizeDelta int64, count int) { + return p.lenDelta.Load(), p.byteSizeDelta.Load(), int(p.invocations.Load()) +} + +// mockManagedQueueSignalRecorder is a thread-safe helper for recording queue state signals. +type mockManagedQueueSignalRecorder struct { + mu sync.Mutex + signals []queueStateSignal +} + +// newMockManagedQueueSignalRecorder initializes the signals slice to be non-nil, preventing assertion failures when +// comparing a nil slice with an empty one. +func newMockManagedQueueSignalRecorder() *mockManagedQueueSignalRecorder { + return &mockManagedQueueSignalRecorder{ + signals: make([]queueStateSignal, 0), + } +} + +func (r *mockManagedQueueSignalRecorder) signal(_ types.FlowSpecification, signal queueStateSignal) { + r.mu.Lock() + defer r.mu.Unlock() + r.signals = append(r.signals, signal) +} + +func (r *mockManagedQueueSignalRecorder) getSignals() []queueStateSignal { r.mu.Lock() defer r.mu.Unlock() - r.lenDelta += lenDelta - r.byteSizeDelta += byteSizeDelta - r.invocationCount++ + // Return a copy to prevent data races if the caller iterates over the slice while new signals are concurrently added + // by the system under test. + signalsCopy := make([]queueStateSignal, len(r.signals)) + copy(signalsCopy, r.signals) + return signalsCopy } -func (r *testStatsReconciler) getStats() (lenDelta, byteSizeDelta int64, count int) { +func (r *mockManagedQueueSignalRecorder) clear() { r.mu.Lock() defer r.mu.Unlock() - return r.lenDelta, r.byteSizeDelta, r.invocationCount + r.signals = make([]queueStateSignal, 0) } -// testFixture holds the components needed for a `managedQueue` test. -type testFixture struct { +// mqTestHarness holds all components for testing a `managedQueue`. +type mqTestHarness struct { + t *testing.T mq *managedQueue mockQueue *frameworkmocks.MockSafeQueue mockPolicy *frameworkmocks.MockIntraFlowDispatchPolicy - reconciler *testStatsReconciler + propagator *mockStatsPropagator + signalRecorder *mockManagedQueueSignalRecorder flowSpec types.FlowSpecification - mockComparator *frameworkmocks.MockItemComparator } -// setupTestManagedQueue creates a new test fixture for testing the `managedQueue`. -func setupTestManagedQueue(t *testing.T) *testFixture { +// newMqTestHarness creates a new test harness. +// The `useRealQueue` flag allows swapping between a mocked `framework.SafeQueue` and a real one for different test +// scenarios. 
+func newMqTestHarness(t *testing.T, useRealQueue bool) *mqTestHarness { t.Helper() - mockQueue := &frameworkmocks.MockSafeQueue{} - reconciler := &testStatsReconciler{} + propagator := &mockStatsPropagator{} + signalRec := newMockManagedQueueSignalRecorder() flowSpec := types.FlowSpecification{ID: "test-flow", Priority: 1} - mockComparator := &frameworkmocks.MockItemComparator{} mockPolicy := &frameworkmocks.MockIntraFlowDispatchPolicy{ - ComparatorV: mockComparator, + ComparatorV: &frameworkmocks.MockItemComparator{}, + } + + var q framework.SafeQueue + var mockQueue *frameworkmocks.MockSafeQueue + + if useRealQueue { + // Use a real queue implementation for concurrency tests or when behavior is complex. + realQueue, err := queue.NewQueueFromName(listqueue.ListQueueName, nil) + require.NoError(t, err, "Test setup: creating a real listqueue should not fail") + q = realQueue + } else { + // Use a mock queue for unit tests to isolate the `managedQueue`'s logic. + mockQueue = &frameworkmocks.MockSafeQueue{} + q = mockQueue } - mq := newManagedQueue( - mockQueue, - mockPolicy, - flowSpec, - logr.Discard(), - reconciler.reconcile, - ) - require.NotNil(t, mq, "newManagedQueue should not return nil") + callbacks := managedQueueCallbacks{ + propagateStatsDelta: propagator.propagate, + signalQueueState: signalRec.signal, + } + mq := newManagedQueue(q, mockPolicy, flowSpec, logr.Discard(), callbacks) + require.NotNil(t, mq, "Test setup: newManagedQueue should not return nil") - return &testFixture{ + return &mqTestHarness{ + t: t, mq: mq, mockQueue: mockQueue, mockPolicy: mockPolicy, - reconciler: reconciler, + propagator: propagator, + signalRecorder: signalRec, flowSpec: flowSpec, - mockComparator: mockComparator, } } -func TestManagedQueue_New(t *testing.T) { - t.Parallel() - f := setupTestManagedQueue(t) +// addItem is a test helper to add an item to the managed queue. +func (h *mqTestHarness) addItem(size uint64) types.QueueItemAccessor { + h.t.Helper() + item := typesmocks.NewMockQueueItemAccessor(size, "req", h.flowSpec.ID) + require.NoError(h.t, h.mq.Add(item), "addItem helper should successfully add item to queue") + return item +} + +// removeItem is a test helper to remove an item from the managed queue. +func (h *mqTestHarness) removeItem(item types.QueueItemAccessor) { + h.t.Helper() + _, err := h.mq.Remove(item.Handle()) + require.NoError(h.t, err, "removeItem helper should successfully remove item from queue") +} - assert.Zero(t, f.mq.Len(), "A new managedQueue should have a length of 0") - assert.Zero(t, f.mq.ByteSize(), "A new managedQueue should have a byte size of 0") - assert.True(t, f.mq.isActive.Load(), "A new managedQueue should be active") +// assertSignals checks that the recorded signals match the expected sequence. +func (h *mqTestHarness) assertSignals(expected ...queueStateSignal) { + h.t.Helper() + // Ensure nil expected slice is treated as empty for consistent assertions. + if expected == nil { + expected = make([]queueStateSignal, 0) + } + assert.Equal(h.t, expected, h.signalRecorder.getSignals(), "The sequence of emitted GC signals should be correct") +} + +// assertStatus verifies the queue's lifecycle status. +func (h *mqTestHarness) assertStatus(expected componentStatus, msgAndArgs ...interface{}) { + h.t.Helper() + assert.Equal(h.t, expected, componentStatus(h.mq.status.Load()), msgAndArgs...) 
} -func TestManagedQueue_Add(t *testing.T) { +func TestManagedQueue(t *testing.T) { t.Parallel() - testCases := []struct { - name string - itemByteSize uint64 - mockAddError error - markAsDraining bool - expectError bool - expectedErrorIs error - expectedLen int - expectedByteSize uint64 - expectedLenDelta int64 - expectedByteSizeDelta int64 - expectedReconcile bool - }{ - { - name: "Success", - itemByteSize: 100, - expectError: false, - expectedLen: 1, - expectedByteSize: 100, - expectedLenDelta: 1, - expectedByteSizeDelta: 100, - expectedReconcile: true, - }, - { - name: "Error from underlying queue", - itemByteSize: 100, - mockAddError: errors.New("queue full"), - expectError: true, - expectedLen: 0, - expectedByteSize: 0, - expectedLenDelta: 0, - expectedByteSizeDelta: 0, - expectedReconcile: false, - }, - { - name: "Error on inactive queue", - itemByteSize: 100, - markAsDraining: true, - expectError: true, - expectedErrorIs: contracts.ErrFlowInstanceNotFound, - expectedLen: 0, - expectedByteSize: 0, - expectedLenDelta: 0, - expectedByteSizeDelta: 0, - expectedReconcile: false, - }, - } + t.Run("New_InitialState", func(t *testing.T) { + t.Parallel() + harness := newMqTestHarness(t, false) - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - f := setupTestManagedQueue(t) + assert.Zero(t, harness.mq.Len(), "A new managedQueue should have a length of 0") + assert.Zero(t, harness.mq.ByteSize(), "A new managedQueue should have a byte size of 0") + harness.assertStatus(componentStatusActive, "A new queue should be in the Active state") + }) - // Configure mock - f.mockQueue.AddFunc = func(item types.QueueItemAccessor) error { - return tc.mockAddError - } + t.Run("Add_Scenarios", func(t *testing.T) { + t.Parallel() - if tc.markAsDraining { - f.mq.markAsDraining() - assert.False(t, f.mq.isActive.Load(), "Setup: queue should be marked as inactive") - } + testCases := []struct { + name string + setup func(h *mqTestHarness) + itemByteSize uint64 + expectErr bool + errIs error + expectedLen int + expectedByteSize uint64 + expectedLenDelta int64 + expectedByteSizeDelta int64 + expectPropagateCall bool + }{ + { + name: "WhenQueueIsActive_ShouldSucceed", + setup: func(h *mqTestHarness) { + h.mockQueue.AddFunc = func(item types.QueueItemAccessor) error { return nil } + }, + itemByteSize: 100, + expectErr: false, + expectedLen: 1, + expectedByteSize: 100, + expectedLenDelta: 1, + expectedByteSizeDelta: 100, + expectPropagateCall: true, + }, + { + name: "WhenUnderlyingQueueFails_ShouldFail", + setup: func(h *mqTestHarness) { + h.mockQueue.AddFunc = func(item types.QueueItemAccessor) error { return errors.New("queue full") } + }, + itemByteSize: 100, + expectErr: true, + expectedLen: 0, + expectedByteSize: 0, + expectedLenDelta: 0, + expectedByteSizeDelta: 0, + expectPropagateCall: false, + }, + { + name: "WhenQueueIsDraining_ShouldFail", + setup: func(h *mqTestHarness) { + h.mq.status.Store(int32(componentStatusDraining)) + }, + itemByteSize: 100, + expectErr: true, + errIs: contracts.ErrFlowInstanceNotFound, + expectedLen: 0, + expectedByteSize: 0, + expectedLenDelta: 0, + expectedByteSizeDelta: 0, + expectPropagateCall: false, + }, + { + name: "WhenQueueIsDrained_ShouldFail", + setup: func(h *mqTestHarness) { + h.mq.status.Store(int32(componentStatusDrained)) + }, + itemByteSize: 100, + expectErr: true, + errIs: contracts.ErrFlowInstanceNotFound, + expectedLen: 0, + expectedByteSize: 0, + expectedLenDelta: 0, + expectedByteSizeDelta: 0, + expectPropagateCall: false, 
+ }, + } - item := typesmocks.NewMockQueueItemAccessor(tc.itemByteSize, "req-1", "test-flow") - err := f.mq.Add(item) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + harness := newMqTestHarness(t, false) - if tc.expectError { - require.Error(t, err, "Add should have returned an error") - if tc.expectedErrorIs != nil { - assert.ErrorIs(t, err, tc.expectedErrorIs, "Error should wrap the expected sentinel error") + if tc.setup != nil { + tc.setup(harness) } - } else { - require.NoError(t, err, "Add should not have returned an error") - } - - // Assert final state - assert.Equal(t, tc.expectedLen, f.mq.Len(), "Final length should be as expected") - assert.Equal(t, tc.expectedByteSize, f.mq.ByteSize(), "Final byte size should be as expected") - - // Assert reconciler state - lenDelta, byteSizeDelta, count := f.reconciler.getStats() - assert.Equal(t, tc.expectedLenDelta, lenDelta, "Reconciler length delta should be as expected") - assert.Equal(t, tc.expectedByteSizeDelta, byteSizeDelta, "Reconciler byte size delta should be as expected") - if tc.expectedReconcile { - assert.Equal(t, 1, count, "Reconciler should have been called once") - } else { - assert.Zero(t, count, "Reconciler should not have been called") - } - }) - } -} -func TestManagedQueue_Remove(t *testing.T) { - t.Parallel() - f := setupTestManagedQueue(t) - - // Setup initial state - initialItem := typesmocks.NewMockQueueItemAccessor(100, "req-1", "test-flow") - f.mockQueue.AddFunc = func(item types.QueueItemAccessor) error { return nil } - err := f.mq.Add(initialItem) - require.NoError(t, err, "Setup: Adding an item should not fail") - require.Equal(t, 1, f.mq.Len(), "Setup: Length should be 1 after adding an item") - - // --- Test Success --- - t.Run("Success", func(t *testing.T) { - // Configure mock for Remove - f.mockQueue.RemoveFunc = func(handle types.QueueItemHandle) (types.QueueItemAccessor, error) { - return initialItem, nil - } + item := typesmocks.NewMockQueueItemAccessor(tc.itemByteSize, "req-1", "test-flow") + err := harness.mq.Add(item) - // Perform Remove - removedItem, err := f.mq.Remove(initialItem.Handle()) - require.NoError(t, err, "Remove should not return an error") - assert.Equal(t, initialItem, removedItem, "Remove should return the correct item") + if tc.expectErr { + require.Error(t, err, "Add should have returned an error") + if tc.errIs != nil { + assert.ErrorIs(t, err, tc.errIs, "Error should wrap the expected sentinel error") + } + } else { + require.NoError(t, err, "Add should not have returned an error") + } - // Assert final state - assert.Zero(t, f.mq.Len(), "Length should be 0 after removing the only item") - assert.Zero(t, f.mq.ByteSize(), "ByteSize should be 0 after removing the only item") + assert.Equal(t, tc.expectedLen, harness.mq.Len(), "Final length should be as expected") + assert.Equal(t, tc.expectedByteSize, harness.mq.ByteSize(), "Final byte size should be as expected") - // Assert reconciler state - lenDelta, byteSizeDelta, count := f.reconciler.getStats() - assert.Equal(t, int64(0), lenDelta, "Net length delta should be 0 after add and remove") - assert.Equal(t, int64(0), byteSizeDelta, "Net byte size delta should be 0 after add and remove") - assert.Equal(t, 2, count, "Reconciler should have been called for both Add and Remove") + lenDelta, byteSizeDelta, count := harness.propagator.getStats() + assert.Equal(t, tc.expectedLenDelta, lenDelta, "Propagator length delta should be as expected") + assert.Equal(t, tc.expectedByteSizeDelta, 
byteSizeDelta, "Propagator byte size delta should be as expected") + if tc.expectPropagateCall { + assert.Equal(t, 1, count, "Propagator should have been called exactly once") + } else { + assert.Zero(t, count, "Propagator should not have been called") + } + }) + } }) - // --- Test Error --- - t.Run("Error", func(t *testing.T) { - f := setupTestManagedQueue(t) - require.NoError(t, f.mq.Add(initialItem), "Setup: Adding an item should not fail") + t.Run("Remove_Scenarios", func(t *testing.T) { + t.Parallel() - // Configure mock to return an error - expectedErr := errors.New("item not found") - f.mockQueue.RemoveFunc = func(handle types.QueueItemHandle) (types.QueueItemAccessor, error) { - return nil, expectedErr + item := typesmocks.NewMockQueueItemAccessor(100, "req-1", "test-flow") + + testCases := []struct { + name string + setupMock func(q *frameworkmocks.MockSafeQueue) + expectErr bool + expectedLen int + expectedByteSize uint64 + expectedPropagatorOps int + }{ + { + name: "WhenRemoveSucceeds_ShouldDecrementStats", + setupMock: func(q *frameworkmocks.MockSafeQueue) { + q.RemoveFunc = func(handle types.QueueItemHandle) (types.QueueItemAccessor, error) { + return item, nil + } + }, + expectErr: false, + expectedLen: 0, + expectedByteSize: 0, + expectedPropagatorOps: 2, // 1 for add, 1 for remove + }, + { + name: "WhenRemoveFails_ShouldNotChangeStats", + setupMock: func(q *frameworkmocks.MockSafeQueue) { + q.RemoveFunc = func(handle types.QueueItemHandle) (types.QueueItemAccessor, error) { + return nil, errors.New("item not found") + } + }, + expectErr: true, + expectedLen: 1, + expectedByteSize: 100, + expectedPropagatorOps: 1, // Only the initial add + }, } - _, err := f.mq.Remove(initialItem.Handle()) - require.ErrorIs(t, err, expectedErr, "Remove should propagate the error") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + harness := newMqTestHarness(t, false) - // Assert state did not change - assert.Equal(t, 1, f.mq.Len(), "Length should not change on a failed remove") - assert.Equal(t, uint64(100), f.mq.ByteSize(), "ByteSize should not change on a failed remove") + // Common setup: add one item to the queue. + harness.mockQueue.AddFunc = func(item types.QueueItemAccessor) error { return nil } + require.NoError(t, harness.mq.Add(item), "Test setup: Adding an item should not fail") - // Assert reconciler was not called for the remove - _, _, count := f.reconciler.getStats() - assert.Equal(t, 1, count, "Reconciler should only have been called for the initial Add") - }) -} + // Apply test-case specific mock behavior. 
+ tc.setupMock(harness.mockQueue) -func TestManagedQueue_CleanupAndDrain(t *testing.T) { - t.Parallel() - item1 := typesmocks.NewMockQueueItemAccessor(10, "req-1", "test-flow") - item2 := typesmocks.NewMockQueueItemAccessor(20, "req-2", "test-flow") - item3 := typesmocks.NewMockQueueItemAccessor(30, "req-3", "test-flow") + _, err := harness.mq.Remove(item.Handle()) + + if tc.expectErr { + require.Error(t, err, "Remove should have returned an error") + } else { + require.NoError(t, err, "Remove should not have returned an error") + } - // --- Test Cleanup --- - t.Run("Cleanup", func(t *testing.T) { + assert.Equal(t, tc.expectedLen, harness.mq.Len(), "Final length should be as expected") + assert.Equal(t, tc.expectedByteSize, harness.mq.ByteSize(), "Final byte size should be as expected") + + _, _, count := harness.propagator.getStats() + assert.Equal(t, tc.expectedPropagatorOps, count, + "Propagator should have been called the expected number of times") + }) + } + }) + + t.Run("Cleanup_Scenarios", func(t *testing.T) { t.Parallel() - f := setupTestManagedQueue(t) - // Add initial items - require.NoError(t, f.mq.Add(item1), "Setup: Add item1 should not fail") - require.NoError(t, f.mq.Add(item2), "Setup: Add item2 should not fail") - require.NoError(t, f.mq.Add(item3), "Setup: Add item3 should not fail") - require.Equal(t, 3, f.mq.Len(), "Setup: Initial length should be 3") - require.Equal(t, uint64(60), f.mq.ByteSize(), "Setup: Initial byte size should be 60") - - // Configure mock to clean up item2 - f.mockQueue.CleanupFunc = func(p framework.PredicateFunc) ([]types.QueueItemAccessor, error) { - return []types.QueueItemAccessor{item2}, nil + item1 := typesmocks.NewMockQueueItemAccessor(10, "req-1", "test-flow") + item2 := typesmocks.NewMockQueueItemAccessor(20, "req-2", "test-flow") + item3 := typesmocks.NewMockQueueItemAccessor(30, "req-3", "test-flow") + + testCases := []struct { + name string + itemsToAdd []types.QueueItemAccessor + setupMock func(q *frameworkmocks.MockSafeQueue) + expectErr bool + expectedFinalLen int + expectedFinalByteSize uint64 + expectedPropagatorOps int + }{ + { + name: "WhenCleanupSucceeds_ShouldDecrementStats", + itemsToAdd: []types.QueueItemAccessor{item1, item2, item3}, + setupMock: func(q *frameworkmocks.MockSafeQueue) { + q.CleanupFunc = func(p framework.PredicateFunc) ([]types.QueueItemAccessor, error) { + // Simulate removing one item (item2) + return []types.QueueItemAccessor{item2}, nil + } + }, + expectErr: false, + expectedFinalLen: 2, + expectedFinalByteSize: 40, // 10 + 30 + expectedPropagatorOps: 4, // 3 adds + 1 cleanup + }, + { + name: "WhenCleanupFails_ShouldNotChangeStats", + itemsToAdd: []types.QueueItemAccessor{item1}, + setupMock: func(q *frameworkmocks.MockSafeQueue) { + q.CleanupFunc = func(p framework.PredicateFunc) ([]types.QueueItemAccessor, error) { + return nil, errors.New("internal error") + } + }, + expectErr: true, + expectedFinalLen: 1, + expectedFinalByteSize: 10, + expectedPropagatorOps: 1, // Only the initial add + }, } - cleaned, err := f.mq.Cleanup(func(i types.QueueItemAccessor) bool { return true }) - require.NoError(t, err, "Cleanup should not return an error") - require.Len(t, cleaned, 1, "Cleanup should return one item") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + harness := newMqTestHarness(t, false) + + // Common setup: add the specified initial items. 
+ harness.mockQueue.AddFunc = func(item types.QueueItemAccessor) error { return nil } + for _, item := range tc.itemsToAdd { + require.NoError(t, harness.mq.Add(item), "Test setup: adding initial items should not fail") + } + + // Apply test-case specific mock behavior. + tc.setupMock(harness.mockQueue) - // Assert final state - assert.Equal(t, 2, f.mq.Len(), "Length should be 2 after cleanup") - assert.Equal(t, uint64(40), f.mq.ByteSize(), "ByteSize should be 40 after cleanup") + _, err := harness.mq.Cleanup(func(i types.QueueItemAccessor) bool { return true }) + + if tc.expectErr { + require.Error(t, err, "Cleanup should have returned an error") + } else { + require.NoError(t, err, "Cleanup should not have returned an error") + } - // Assert reconciler state (3 adds, 1 cleanup) - lenDelta, byteSizeDelta, count := f.reconciler.getStats() - assert.Equal(t, int64(2), lenDelta, "Net length delta should be 2") - assert.Equal(t, int64(40), byteSizeDelta, "Net byte size delta should be 40") - assert.Equal(t, 4, count, "Reconciler should have been called 4 times") + assert.Equal(t, tc.expectedFinalLen, harness.mq.Len(), "Final length should be as expected") + assert.Equal(t, tc.expectedFinalByteSize, harness.mq.ByteSize(), "Final byte size should be as expected") + _, _, count := harness.propagator.getStats() + assert.Equal(t, tc.expectedPropagatorOps, count, + "Propagator should have been called the expected number of times") + }) + } }) - // --- Test Drain --- - t.Run("Drain", func(t *testing.T) { + t.Run("Drain_Scenarios", func(t *testing.T) { t.Parallel() - f := setupTestManagedQueue(t) - // Add initial items - require.NoError(t, f.mq.Add(item1), "Setup: Add item1 should not fail") - require.NoError(t, f.mq.Add(item2), "Setup: Add item2 should not fail") - require.Equal(t, 2, f.mq.Len(), "Setup: Initial length should be 2") - require.Equal(t, uint64(30), f.mq.ByteSize(), "Setup: Initial byte size should be 30") - - // Configure mock to drain both items - f.mockQueue.DrainFunc = func() ([]types.QueueItemAccessor, error) { - return []types.QueueItemAccessor{item1, item2}, nil + item1 := typesmocks.NewMockQueueItemAccessor(10, "req-1", "test-flow") + item2 := typesmocks.NewMockQueueItemAccessor(20, "req-2", "test-flow") + + testCases := []struct { + name string + itemsToAdd []types.QueueItemAccessor + setupMock func(q *frameworkmocks.MockSafeQueue) + expectErr bool + expectedFinalLen int + expectedFinalByteSize uint64 + expectedPropagatorOps int + }{ + { + name: "WhenDrainSucceeds_ShouldDecrementStats", + itemsToAdd: []types.QueueItemAccessor{item1, item2}, + setupMock: func(q *frameworkmocks.MockSafeQueue) { + q.DrainFunc = func() ([]types.QueueItemAccessor, error) { + return []types.QueueItemAccessor{item1, item2}, nil + } + }, + expectErr: false, + expectedFinalLen: 0, + expectedFinalByteSize: 0, + expectedPropagatorOps: 3, // 2 adds + 1 drain + }, + { + name: "WhenDrainFails_ShouldNotChangeStats", + itemsToAdd: []types.QueueItemAccessor{item1}, + setupMock: func(q *frameworkmocks.MockSafeQueue) { + q.DrainFunc = func() ([]types.QueueItemAccessor, error) { + return nil, errors.New("internal error") + } + }, + expectErr: true, + expectedFinalLen: 1, + expectedFinalByteSize: 10, + expectedPropagatorOps: 1, // Only the initial add + }, } - drained, err := f.mq.Drain() - require.NoError(t, err, "Drain should not return an error") - require.Len(t, drained, 2, "Drain should return two items") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + harness := 
newMqTestHarness(t, false) - // Assert final state - assert.Zero(t, f.mq.Len(), "Length should be 0 after drain") - assert.Zero(t, f.mq.ByteSize(), "ByteSize should be 0 after drain") + // Common setup: add the specified initial items. + harness.mockQueue.AddFunc = func(item types.QueueItemAccessor) error { return nil } + for _, item := range tc.itemsToAdd { + require.NoError(t, harness.mq.Add(item), "Test setup: adding initial items should not fail") + } + + // Apply test-case specific mock behavior. + tc.setupMock(harness.mockQueue) + + _, err := harness.mq.Drain() + + if tc.expectErr { + require.Error(t, err, "Drain should have returned an error") + } else { + require.NoError(t, err, "Drain should not have returned an error") + } - // Assert reconciler state (2 adds, 1 drain) - lenDelta, byteSizeDelta, count := f.reconciler.getStats() - assert.Equal(t, int64(0), lenDelta, "Net length delta should be 0") - assert.Equal(t, int64(0), byteSizeDelta, "Net byte size delta should be 0") - assert.Equal(t, 3, count, "Reconciler should have been called 3 times") + assert.Equal(t, tc.expectedFinalLen, harness.mq.Len(), "Final length should be as expected") + assert.Equal(t, tc.expectedFinalByteSize, harness.mq.ByteSize(), "Final byte size should be as expected") + _, _, count := harness.propagator.getStats() + assert.Equal(t, tc.expectedPropagatorOps, count, + "Propagator should have been called the expected number of times") + }) + } }) - // --- Test Error Paths --- - t.Run("ErrorPaths", func(t *testing.T) { - f := setupTestManagedQueue(t) - require.NoError(t, f.mq.Add(item1), "Setup: Adding an item should not fail") - initialLen, initialByteSize := f.mq.Len(), f.mq.ByteSize() + t.Run("FlowQueueAccessor_ProxiesCalls", func(t *testing.T) { + t.Parallel() + harness := newMqTestHarness(t, false) + item := typesmocks.NewMockQueueItemAccessor(100, "req-1", "test-flow") + + harness.mockQueue.PeekHeadV = item + harness.mockQueue.PeekTailV = item + harness.mockQueue.NameV = "MockQueue" + harness.mockQueue.CapabilitiesV = []framework.QueueCapability{framework.CapabilityFIFO} + harness.mockQueue.AddFunc = func(item types.QueueItemAccessor) error { return nil } + + require.NoError(t, harness.mq.Add(item), "Test setup: Adding an item should not fail") + + accessor := harness.mq.FlowQueueAccessor() + require.NotNil(t, accessor, "FlowQueueAccessor should not be nil") + + assert.Equal(t, harness.mq.Name(), accessor.Name(), "Accessor Name() should match managed queue") + assert.Equal(t, harness.mq.Capabilities(), accessor.Capabilities(), + "Accessor Capabilities() should match managed queue") + assert.Equal(t, harness.mq.Len(), accessor.Len(), "Accessor Len() should match managed queue") + assert.Equal(t, harness.mq.ByteSize(), accessor.ByteSize(), "Accessor ByteSize() should match managed queue") + assert.Equal(t, harness.flowSpec, accessor.FlowSpec(), "Accessor FlowSpec() should match managed queue") + assert.Equal(t, harness.mockPolicy.Comparator(), accessor.Comparator(), + "Accessor Comparator() should match the one from the policy") + assert.Equal(t, harness.mockPolicy.Comparator(), harness.mq.Comparator(), + "ManagedQueue Comparator() should also match the one from the policy") + + peekedHead, err := accessor.PeekHead() + require.NoError(t, err, "Accessor PeekHead() should not return an error") + assert.Same(t, item, peekedHead, "Accessor PeekHead() should return the correct item instance") + + peekedTail, err := accessor.PeekTail() + require.NoError(t, err, "Accessor PeekTail() should not return an 
error") + assert.Same(t, item, peekedTail, "Accessor PeekTail() should return the correct item instance") + }) - expectedErr := errors.New("internal error") + t.Run("StateTransitions_EmitGCSignals", func(t *testing.T) { + t.Parallel() - // Cleanup error - f.mockQueue.CleanupFunc = func(p framework.PredicateFunc) ([]types.QueueItemAccessor, error) { - return nil, expectedErr + type step struct { + action string // "add", "remove", "markDraining" + itemSize uint64 + expectedSignals []queueStateSignal + expectedStatus componentStatus } - _, err := f.mq.Cleanup(func(i types.QueueItemAccessor) bool { return true }) - require.ErrorIs(t, err, expectedErr, "Cleanup should propagate error") - assert.Equal(t, initialLen, f.mq.Len(), "Len should not change on Cleanup error") - assert.Equal(t, initialByteSize, f.mq.ByteSize(), "ByteSize should not change on Cleanup error") - - // Drain error - f.mockQueue.DrainFunc = func() ([]types.QueueItemAccessor, error) { - return nil, expectedErr + + testCases := []struct { + name string + steps []step + expectedLen int + expectedBytes uint64 + expectedStatus componentStatus + }{ + { + name: "WhenAddingToEmptyActiveQueue_ShouldSignalBecameNonEmpty", + steps: []step{ + { + action: "add", + itemSize: 100, + expectedSignals: []queueStateSignal{queueStateSignalBecameNonEmpty}, + expectedStatus: componentStatusActive, + }, + }, + expectedLen: 1, + expectedBytes: 100, + expectedStatus: componentStatusActive, + }, + { + name: "WhenAddingToNonEmptyActiveQueue_ShouldNotSignal", + steps: []step{ + { + action: "add", + itemSize: 100, + expectedSignals: []queueStateSignal{queueStateSignalBecameNonEmpty}, + expectedStatus: componentStatusActive, + }, + { + action: "add", + itemSize: 50, + expectedSignals: nil, + expectedStatus: componentStatusActive, + }, + }, + expectedLen: 2, + expectedBytes: 150, + expectedStatus: componentStatusActive, + }, + { + name: "WhenRemovingLastItemFromActiveQueue_ShouldSignalBecameEmpty", + steps: []step{ + { + action: "add", + itemSize: 100, + expectedSignals: []queueStateSignal{queueStateSignalBecameNonEmpty}, + expectedStatus: componentStatusActive, + }, + { + action: "remove", + expectedSignals: []queueStateSignal{queueStateSignalBecameEmpty}, + expectedStatus: componentStatusActive, + }, + }, + expectedLen: 0, + expectedBytes: 0, + expectedStatus: componentStatusActive, + }, + { + name: "WhenRemovingFromMultiItemActiveQueue_ShouldNotSignal", + steps: []step{ + { + action: "add", + itemSize: 100, + expectedSignals: []queueStateSignal{queueStateSignalBecameNonEmpty}, + expectedStatus: componentStatusActive, + }, + { + action: "add", + itemSize: 50, + expectedSignals: nil, + expectedStatus: componentStatusActive, + }, + { + action: "remove", + expectedSignals: nil, + expectedStatus: componentStatusActive, + }, + }, + expectedLen: 1, + expectedBytes: 100, + expectedStatus: componentStatusActive, + }, + { + name: "WhenRemovingLastItemFromDrainingQueue_ShouldSignalBecameDrained", + steps: []step{ + { + action: "add", + itemSize: 100, + expectedSignals: []queueStateSignal{queueStateSignalBecameNonEmpty}, + expectedStatus: componentStatusActive, + }, + { + action: "markDraining", + expectedSignals: nil, + expectedStatus: componentStatusDraining, + }, + { + action: "remove", + expectedSignals: []queueStateSignal{queueStateSignalBecameDrained}, + expectedStatus: componentStatusDrained, + }, + }, + expectedLen: 0, + expectedBytes: 0, + expectedStatus: componentStatusDrained, + }, + { + name: "WhenMarkingEmptyQueueAsDraining_ShouldSignalBecameDrained", 
+ steps: []step{ + { + action: "markDraining", + expectedSignals: []queueStateSignal{queueStateSignalBecameDrained}, + expectedStatus: componentStatusDrained, + }, + }, + expectedLen: 0, + expectedBytes: 0, + expectedStatus: componentStatusDrained, + }, } - _, err = f.mq.Drain() - require.ErrorIs(t, err, expectedErr, "Drain should propagate error") - assert.Equal(t, initialLen, f.mq.Len(), "Len should not change on Drain error") - assert.Equal(t, initialByteSize, f.mq.ByteSize(), "ByteSize should not change on Drain error") - }) -} -func TestManagedQueue_FlowQueueAccessor(t *testing.T) { - t.Parallel() - f := setupTestManagedQueue(t) - item := typesmocks.NewMockQueueItemAccessor(100, "req-1", "test-flow") - - // Setup underlying queue state - f.mockQueue.PeekHeadV = item - f.mockQueue.PeekTailV = item - f.mockQueue.NameV = "MockQueue" - f.mockQueue.CapabilitiesV = []framework.QueueCapability{framework.CapabilityFIFO} - - // Add an item to populate the managed queue's stats - require.NoError(t, f.mq.Add(item), "Setup: Adding an item should not fail") - - // Get the accessor - accessor := f.mq.FlowQueueAccessor() - require.NotNil(t, accessor, "FlowQueueAccessor should not be nil") - - // Assert that the accessor methods reflect the underlying state - assert.Equal(t, f.mq.Name(), accessor.Name(), "Accessor Name() should match managed queue") - assert.Equal(t, f.mq.Capabilities(), accessor.Capabilities(), "Accessor Capabilities() should match managed queue") - assert.Equal(t, f.mq.Len(), accessor.Len(), "Accessor Len() should match managed queue") - assert.Equal(t, f.mq.ByteSize(), accessor.ByteSize(), "Accessor ByteSize() should match managed queue") - assert.Equal(t, f.flowSpec, accessor.FlowSpec(), "Accessor FlowSpec() should match managed queue") - assert.Equal(t, f.mockComparator, accessor.Comparator(), "Accessor Comparator() should match the one from the policy") - assert.Equal(t, f.mockComparator, f.mq.Comparator(), "ManagedQueue Comparator() should match the one from the policy") - - peekedHead, err := accessor.PeekHead() - require.NoError(t, err, "Accessor PeekHead() should not return an error") - assert.Equal(t, item, peekedHead, "Accessor PeekHead() should return the correct item") - - peekedTail, err := accessor.PeekTail() - require.NoError(t, err, "Accessor PeekTail() should not return an error") - assert.Equal(t, item, peekedTail, "Accessor PeekTail() should return the correct item") -} + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + harness := newMqTestHarness(t, true) + var lastItem types.QueueItemAccessor + + for i, step := range tc.steps { + harness.signalRecorder.clear() + switch step.action { + case "add": + lastItem = harness.addItem(step.itemSize) + case "remove": + require.NotNil(t, lastItem, "Test setup error: cannot remove without a prior add") + harness.removeItem(lastItem) + case "markDraining": + harness.mq.markAsDraining() + default: + t.Fatalf("Unknown test step action: %s", step.action) + } -func TestManagedQueue_Concurrency(t *testing.T) { - t.Parallel() + harness.assertSignals(step.expectedSignals...) + harness.assertStatus(step.expectedStatus, "Step %d (%s): status mismatch", i, step.action) + } - // Use a real `listqueue` since it's concurrent-safe. 
- lq, err := queue.NewQueueFromName(listqueue.ListQueueName, nil) - require.NoError(t, err, "Setup: creating a real listqueue should not fail") + assert.Equal(t, tc.expectedLen, harness.mq.Len(), "Final queue length mismatch") + assert.Equal(t, tc.expectedBytes, harness.mq.ByteSize(), "Final queue byte size mismatch") + harness.assertStatus(tc.expectedStatus, "Final queue status mismatch") + }) + } + }) - reconciler := &testStatsReconciler{} - flowSpec := types.FlowSpecification{ID: "conc-test-flow", Priority: 1} - mq := newManagedQueue(lq, nil, flowSpec, logr.Discard(), reconciler.reconcile) + t.Run("Lifecycle", func(t *testing.T) { + t.Parallel() - const ( - numGoroutines = 20 - opsPerGoroutine = 200 - itemByteSize = 10 - initialItems = 500 - ) + t.Run("StateTransitions", func(t *testing.T) { + t.Parallel() - var wg sync.WaitGroup - var successfulAdds, successfulRemoves atomic.Int64 + t.Run("NewQueue_ShouldBeActive", func(t *testing.T) { + t.Parallel() + harness := newMqTestHarness(t, true) + harness.assertStatus(componentStatusActive) + }) + + t.Run("MarkAsDraining_OnEmptyQueue_ShouldTransitionToDrained", func(t *testing.T) { + t.Parallel() + harness := newMqTestHarness(t, true) + harness.mq.markAsDraining() + harness.assertStatus(componentStatusDrained, + "An empty queue should immediately become Drained when marked as such") + }) + + t.Run("Reactivate_OnDrainedQueue_ShouldTransitionToActive", func(t *testing.T) { + t.Parallel() + harness := newMqTestHarness(t, true) + harness.mq.status.Store(int32(componentStatusDrained)) // Setup initial state + harness.mq.reactivate() + harness.assertStatus(componentStatusActive, "A Drained queue should become Active after reactivation") + }) + + t.Run("MarkAsDraining_OnNonEmptyQueue_ShouldTransitionToDraining", func(t *testing.T) { + t.Parallel() + harness := newMqTestHarness(t, true) + harness.addItem(10) + harness.mq.markAsDraining() + harness.assertStatus(componentStatusDraining, + "A non-empty queue should stay Draining, not become Drained, until empty") + }) + + t.Run("Reactivate_OnDrainingQueue_ShouldTransitionToActive", func(t *testing.T) { + t.Parallel() + harness := newMqTestHarness(t, true) + harness.addItem(10) + harness.mq.markAsDraining() + require.Equal(t, componentStatusDraining, componentStatus(harness.mq.status.Load()), + "Test setup: queue should be in Draining state before reactivation") + + harness.mq.reactivate() + + harness.assertStatus(componentStatusActive, "A Draining queue should become Active after reactivation") + assert.Equal(t, 1, harness.mq.Len(), "Queue should still contain its item after reactivation") + }) + }) + + // DrainingRace specifically targets the race condition between marking a queue as draining and the queue + // concurrently becoming empty. This test ensures that the state machine correctly and atomically transitions to + // Drained, sending the signal exactly once. + t.Run("DrainingRace", func(t *testing.T) { + t.Parallel() + harness := newMqTestHarness(t, true) + item := harness.addItem(10) + + var wg sync.WaitGroup + wg.Add(2) + + // Goroutine 1: Vigorously attempts to mark the queue as draining. + go func() { + defer wg.Done() + harness.mq.markAsDraining() + }() + + // Goroutine 2: Vigorously attempts to remove the single item. + go func() { + defer wg.Done() + harness.removeItem(item) + }() + + wg.Wait() + + // Verification: + // The core assertion is that no matter which operation "won" the race, the final state is deterministically + // Drained. 
+ // This proves the atomic CAS operations in the state machine are correct. + harness.assertStatus(componentStatusDrained, "Final state must be Drained regardless of the race outcome") + + // We also verify that the correct signal was sent exactly once. The exact sequence of signals can vary depending + // on the race, but the final Drained signal is guaranteed. + signals := harness.signalRecorder.getSignals() + assert.Contains(t, signals, queueStateSignalBecameDrained, "The BecameDrained signal must be sent") + count := 0 + for _, s := range signals { + if s == queueStateSignalBecameDrained { + count++ + } + } + assert.Equal(t, 1, count, "The BecameDrained signal must be sent exactly once") + }) - // Use a channel as a concurrent-safe pool of handles for removal. - handles := make(chan types.QueueItemHandle, initialItems+(numGoroutines*opsPerGoroutine)) + // ActiveFlappingRace targets the race condition of a queue rapidly transitioning between empty and non-empty + // states. This test ensures the `BecameEmpty` and `BecameNonEmpty` signals are sent correctly in strict + // alternation, without duplicates or missed signals. + t.Run("ActiveFlappingRace", func(t *testing.T) { + t.Parallel() - // Pre-fill the queue to give removers something to do immediately. - for range initialItems { - item := typesmocks.NewMockQueueItemAccessor(uint64(itemByteSize), "initial", "flow") - require.NoError(t, mq.Add(item), "Setup: pre-filling queue should not fail") - handles <- item.Handle() - } - // Reset reconciler stats after setup so we only measure the concurrent phase. - *reconciler = testStatsReconciler{} - - // Launch goroutines to perform a mix of concurrent operations. - wg.Add(numGoroutines) - for i := range numGoroutines { - go func(routineID int) { - defer wg.Done() - for j := range opsPerGoroutine { - // Mix up operations between adding and removing. - if (routineID+j)%2 == 0 { - // Add operation - item := typesmocks.NewMockQueueItemAccessor(uint64(itemByteSize), "req", "flow") - if err := mq.Add(item); err == nil { - successfulAdds.Add(1) - handles <- item.Handle() + numGoroutines := 4 + opsPerGoroutine := 100 + if testing.Short() { + t.Log("Running in -short mode, reducing workload.") + numGoroutines = 2 + opsPerGoroutine = 50 + } + const itemByteSize = 10 + + harness := newMqTestHarness(t, true) + // This channel safely passes items from producer goroutines to consumer goroutines. + itemsToProcess := make(chan types.QueueItemAccessor, numGoroutines*opsPerGoroutine) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + var wg sync.WaitGroup + + // Start Adder goroutines. + wg.Add(numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + for j := 0; j < opsPerGoroutine; j++ { + select { + case <-ctx.Done(): + return + default: + item := typesmocks.NewMockQueueItemAccessor(itemByteSize, "req", "flow") + if err := harness.mq.Add(item); err == nil { + itemsToProcess <- item + } + } } - } else { - // Remove operation - select { - case handle := <-handles: - if _, err := mq.Remove(handle); err == nil { - successfulRemoves.Add(1) + }() + } + + // Start Remover goroutines. + wg.Add(numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + for { + select { + case <-ctx.Done(): + return + case item := <-itemsToProcess: + _, _ = harness.mq.Remove(item.Handle()) } - default: - // No handles available, do nothing. This can happen if removers are faster than adders. 
} - } + }() } - }(i) - } - wg.Wait() - // All concurrent operations are complete. Drain any remaining items to get the final count. - drainedItems, err := mq.Drain() - require.NoError(t, err, "Draining the queue at the end should not fail") + wg.Wait() + + // Verification: + // The critical part of this test is to analyze the sequence of signals. + signals := harness.signalRecorder.getSignals() + require.NotEmpty(t, signals, "At least some signals should have been generated") - // Final consistency checks. - finalItemCount := len(drainedItems) - finalByteSize := mq.ByteSize() + // The sequence must be a strict alternation of NonEmpty and Empty signals. + // There should never be two of the same signal in a row. + for i := 0; i < len(signals)-1; i++ { + assert.NotEqual(t, signals[i], signals[i+1], "Signals at index %d and %d must not be duplicates", i, i+1) + } - // The number of items left in the queue should match our tracking. - expectedFinalItemCount := initialItems + int(successfulAdds.Load()) - int(successfulRemoves.Load()) - assert.Equal(t, expectedFinalItemCount, finalItemCount, "Final item count should match initial + adds - removes") + // Depending on the timing, the sequence can start with NonEmpty and end with either. + if signals[0] != queueStateSignalBecameNonEmpty { + assert.Fail(t, "The first signal must be BecameNonEmpty") + } + }) + }) - // The queue's internal stats must be zero after draining. - assert.Zero(t, mq.Len(), "Managed queue length must be zero after drain") - assert.Zero(t, finalByteSize, "Managed queue byte size must be zero after drain") + t.Run("Concurrency_StressTest", func(t *testing.T) { + t.Parallel() + harness := newMqTestHarness(t, true) + + numGoroutines := 20 + opsPerGoroutine := 200 + initialItems := 500 + if testing.Short() { + t.Log("Running in -short mode, reducing workload.") + numGoroutines = 4 + opsPerGoroutine = 50 + initialItems = 100 + } - // The net change reported to the reconciler must match the net change from the concurrent phase, - // plus the final drain. - netLenChangeDuringConcurrentPhase := successfulAdds.Load() - successfulRemoves.Load() - netByteSizeChangeDuringConcurrentPhase := netLenChangeDuringConcurrentPhase * itemByteSize + const itemByteSize = 10 - lenDelta, byteSizeDelta, _ := reconciler.getStats() + var wg sync.WaitGroup + var successfulAdds, successfulRemoves atomic.Int64 - expectedLenDelta := netLenChangeDuringConcurrentPhase - int64(finalItemCount) - expectedByteSizeDelta := netByteSizeChangeDuringConcurrentPhase - int64(uint64(finalItemCount)*itemByteSize) + // This channel safely passes item handles from producer goroutines to consumer goroutines. + handles := make(chan types.QueueItemHandle, initialItems+(numGoroutines*opsPerGoroutine)) - assert.Equal(t, expectedLenDelta, lenDelta, - "Net length delta in reconciler should match the net change from all operations") - assert.Equal(t, expectedByteSizeDelta, byteSizeDelta, - "Net byte size delta in reconciler should match the net change from all operations") + // Pre-fill the queue to ensure removals can happen from the start. + for range initialItems { + item := typesmocks.NewMockQueueItemAccessor(uint64(itemByteSize), "initial", "flow") + require.NoError(t, harness.mq.Add(item), "Test setup: pre-filling queue should not fail") + handles <- item.Handle() + } + // Reset the propagator to only measure the concurrent phase and final drain. 
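+ // (The pre-fill deltas were recorded by the original propagator and are intentionally discarded here, so the
+ // assertions below cover only the concurrent phase plus the final drain.)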
+ harness.propagator = &mockStatsPropagator{} + harness.mq.parentCallbacks.propagateStatsDelta = harness.propagator.propagate + + wg.Add(numGoroutines) + for i := range numGoroutines { + go func(routineID int) { + defer wg.Done() + for j := range opsPerGoroutine { + // Alternate between adding and removing items. + if (routineID+j)%2 == 0 { + item := typesmocks.NewMockQueueItemAccessor(uint64(itemByteSize), "req", "flow") + if err := harness.mq.Add(item); err == nil { + successfulAdds.Add(1) + handles <- item.Handle() + } + } else { + select { + case handle := <-handles: + if _, err := harness.mq.Remove(handle); err == nil { + successfulRemoves.Add(1) + } + default: + // No handle available, skip this removal attempt. + } + } + } + }(i) + } + wg.Wait() + + drainedItems, err := harness.mq.Drain() + require.NoError(t, err, "Draining the queue at the end should not fail") + + finalItemCount := len(drainedItems) + + // Core correctness check: The final number of items in the queue must exactly match the initial number, plus all + // successful concurrent additions, minus all successful concurrent removals. + // This proves that no items were lost or duplicated during concurrent operations. + expectedFinalItemCount := initialItems + int(successfulAdds.Load()) - int(successfulRemoves.Load()) + assert.Equal(t, expectedFinalItemCount, finalItemCount, + "Final item count must match initial + adds - removes, proving no item loss") + + // After a successful drain, the managed queue's own counters must be zero. + assert.Zero(t, harness.mq.Len(), "Managed queue length must be zero after drain") + assert.Zero(t, harness.mq.ByteSize(), "Managed queue byte size must be zero after drain") + + // End-to-end statistics check: The net change recorded by the stats propagator must match the net effect of all + // operations (concurrent phase + final drain). + // This validates the atomic stats propagation logic across multiple phases. + netLenChangeDuringConcurrentPhase := successfulAdds.Load() - successfulRemoves.Load() + netByteSizeChangeDuringConcurrentPhase := netLenChangeDuringConcurrentPhase * itemByteSize + + // The final drain operation also propagates a negative delta. + lenDeltaFromDrain := -int64(finalItemCount) + byteSizeDeltaFromDrain := -int64(uint64(finalItemCount) * itemByteSize) + + lenDelta, byteSizeDelta, _ := harness.propagator.getStats() + expectedLenDelta := netLenChangeDuringConcurrentPhase + lenDeltaFromDrain + expectedByteSizeDelta := netByteSizeChangeDuringConcurrentPhase + byteSizeDeltaFromDrain + + assert.Equal(t, expectedLenDelta, lenDelta, + "Net length delta in propagator must match the net change from all operations") + assert.Equal(t, expectedByteSizeDelta, byteSizeDelta, + "Net byte size delta in propagator must match the net change from all operations") + }) } diff --git a/pkg/epp/flowcontrol/registry/registry.go b/pkg/epp/flowcontrol/registry/registry.go new file mode 100644 index 000000000..6b5ecfb35 --- /dev/null +++ b/pkg/epp/flowcontrol/registry/registry.go @@ -0,0 +1,848 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +package registry + +import ( + "context" + "fmt" + "slices" + "sync" + "sync/atomic" + + "github.com/go-logr/logr" + "k8s.io/utils/clock" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework" + intra "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/policies/intraflow/dispatch" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/queue" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +// bandStats holds the aggregated atomic statistics for a single priority band across all shards. +type bandStats struct { + byteSize atomic.Uint64 + len atomic.Uint64 +} + +// FlowRegistry is the concrete implementation of the `contracts.FlowRegistry` interface. It is the top-level +// administrative object and the stateful control plane for the entire flow control system. +// +// # Role: The Central Orchestrator +// +// The `FlowRegistry` is the single source of truth for all configuration and the lifecycle manager for all shards and +// flows. It is responsible for complex, multi-step operations such as flow registration, dynamic shard scaling, and +// coordinating garbage collection across all shards. +// +// # Concurrency: The Serialized Control Plane (Actor Model) +// +// To ensure correctness during complex state transitions, the `FlowRegistry` employs an Actor-like pattern. All +// administrative operations and internal state change events (e.g., GC timers, queue signals) are serialized. +// +// A single background goroutine (the `Run` loop) processes events from the `events` channel. This event loop acquires +// the main `mu` lock before processing any event. External administrative methods (like `RegisterOrUpdateFlow`) also +// acquire this lock. This strict serialization eliminates race conditions in the control plane, simplifying the complex +// logic of the distributed state machine. +// +// (See package documentation in `doc.go` for the full overview of the Architecture and Concurrency Model) +type FlowRegistry struct { + config *Config + logger logr.Logger + + // clock provides the time abstraction for the registry and its components (like `gcTracker`). + clock clock.WithTickerAndDelayedExecution + + // mu protects all administrative operations and the internal state of the registry (shard lists, `flowState`s, + // `nextShardID`). + // It ensures that these complex, multi-step operations appear atomic to the rest of the system. + mu sync.Mutex + + // activeShards contains shards that are currently processing requests (Active). + // We use a slice to maintain deterministic ordering, which is crucial for configuration partitioning. + activeShards []*registryShard + + // drainingShards contains shards that are being gracefully shut down (Draining or Drained). + drainingShards []*registryShard + + // nextShardID is a monotonically increasing counter used to generate unique, stable IDs for shards throughout the + // lifetime of the process. + nextShardID uint64 + + // flowStates tracks the desired state and GC state of all flows, keyed by flow ID. + flowStates map[string]*flowState + + // gc is the decoupled timer manager. + gc *gcTracker + + // events is a channel for all internal state change events from shards and queues. 
+ // This channel drives the serialized event loop (the control plane). + // + // CRITICAL: This is a buffered channel. Sends to this channel MUST NOT be dropped, as the GC system relies on + // exactly-once delivery of edge-triggered events. If this buffer fills, the data path (which sends the events) WILL + // block, applying necessary backpressure to the control plane. + events chan event + + // Globally aggregated statistics. + totalByteSize atomic.Uint64 + totalLen atomic.Uint64 + // perPriorityBandStats stores *bandStats, keyed by priority (uint). + perPriorityBandStats sync.Map +} + +var _ contracts.FlowRegistry = &FlowRegistry{} + +// RegistryOption allows configuring the `FlowRegistry` during initialization using functional options. +type RegistryOption func(*FlowRegistry) + +// WithClock sets the clock abstraction used by the registry (primarily for GC timers). +// This is essential for deterministic testing. If `clk` is nil, the option is ignored. +func WithClock(clk clock.WithTickerAndDelayedExecution) RegistryOption { + return func(fr *FlowRegistry) { + if clk != nil { + fr.clock = clk + } + } +} + +// NewFlowRegistry creates and initializes a new `FlowRegistry` instance. +func NewFlowRegistry( + config *Config, + initialShardCount uint, + logger logr.Logger, + opts ...RegistryOption, +) (*FlowRegistry, error) { + if err := config.validateAndApplyDefaults(); err != nil { + return nil, fmt.Errorf("master configuration is invalid: %w", err) + } + + // Buffered channel to absorb bursts of events. See comment on the struct field for concurrency notes. + events := make(chan event, config.EventChannelBufferSize) + + fr := &FlowRegistry{ + config: config, + logger: logger.WithName("flow-registry"), + flowStates: make(map[string]*flowState), + // gc initialized below after options + events: events, + activeShards: []*registryShard{}, + } + + // Apply functional options (e.g., injecting a `FakeClock`) + for _, opt := range opts { + opt(fr) + } + + // If no clock was provided, default to the real system clock. + if fr.clock == nil { + fr.clock = &clock.RealClock{} + } + + fr.gc = newGCTracker(events, fr.clock) + + for i := range config.PriorityBands { + band := &config.PriorityBands[i] + fr.perPriorityBandStats.Store(band.Priority, &bandStats{}) + } + + // `UpdateShardCount` handles the initial creation and populates `activeShards`. + if err := fr.UpdateShardCount(initialShardCount); err != nil { + return nil, fmt.Errorf("failed to initialize shards: %w", err) + } + + fr.logger.Info("FlowRegistry initialized successfully", "initialShardCount", initialShardCount) + return fr, nil +} + +// Run starts the registry's background event processing loop. It blocks until the provided context is cancelled. +// This loop implements the serialized control plane (Actor model). +func (fr *FlowRegistry) Run(ctx context.Context) { + fr.logger.Info("Starting FlowRegistry event loop") + defer fr.logger.Info("FlowRegistry event loop stopped") + + for { + select { + case <-ctx.Done(): + return + case evt := <-fr.events: + // Acquire the lock to serialize event processing with administrative operations. + fr.mu.Lock() + switch e := evt.(type) { + case *gcTimerFiredEvent: + fr.onGCTimerFired(e) + case *queueStateChangedEvent: + fr.onQueueStateChanged(e) + case *shardStateChangedEvent: + fr.onShardStateChanged(e) + case *syncEvent: + close(e.doneCh) // Synchronization point reached. Acknowledge the caller. 
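+ // Closing doneCh gives callers (primarily tests) a barrier: every event sent before this syncEvent has
+ // already been applied under fr.mu by the time the channel is observed closed.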
+ } + fr.mu.Unlock() + } + } +} + +// RegisterOrUpdateFlow handles the registration of a new flow or the update of an existing flow's specification. +// It orchestrates the validation and commit phases atomically across all shards. +// +// Optimization: It employs a "prepare-commit" pattern. The potentially expensive preparation phase (plugin +// instantiation and allocation) is performed outside the main lock to minimize contention. A revalidation step ensures +// consistency if the shard count changes concurrently. +func (fr *FlowRegistry) RegisterOrUpdateFlow(spec types.FlowSpecification) error { + if spec.ID == "" { + return contracts.ErrFlowIDEmpty + } + + // 1. Get a snapshot of the current total shard count under lock. + // We need the count of ALL shards (active + draining) as the update must be applied everywhere. + fr.mu.Lock() + initialTotalShardCount := len(fr.activeShards) + len(fr.drainingShards) + fr.mu.Unlock() + + // 2. Phase 1: Preparation (Outside the lock). + // This involves potentially slow allocations and plugin initialization. + policy, queues, err := fr.buildFlowComponents(spec, initialTotalShardCount) + if err != nil { + return err + } + + // 3. Phase 2: Commit (Inside the lock). + fr.mu.Lock() + defer fr.mu.Unlock() + + currentTotalShardCount := len(fr.activeShards) + len(fr.drainingShards) + + // 4. Revalidation: Check if the shard count changed while we were preparing. + if currentTotalShardCount != initialTotalShardCount { + // Rare race condition: Shard scaling occurred concurrently. + // Instead of re-preparing everything, we optimize by adjusting the existing `queues` slice. + fr.logger.V(logging.DEBUG).Info("Total shard count changed during preparation, adjusting queues under lock", + "flowID", spec.ID, "initialCount", initialTotalShardCount, "currentCount", currentTotalShardCount) + if currentTotalShardCount > initialTotalShardCount { + // Scale-up: Prepare queues only for the new shards. + numNewShards := currentTotalShardCount - initialTotalShardCount + _, newQueues, err := fr.buildFlowComponents(spec, numNewShards) + if err != nil { + return err // Unlikely, but handled defensively. + } + queues = append(queues, newQueues...) + } else { + // Scale-down: Truncate the slice of prepared queues. + // This is safe because scale-down always removes shards from the end of the list, and the `allShards` slice is + // ordered with active shards first. + queues = queues[:currentTotalShardCount] + } + } + + // 5. Apply the changes. This phase is infallible. + fr.applyFlowSynchronizationLocked(spec, policy, queues) + return nil +} + +// UpdateShardCount dynamically adjusts the number of internal state shards. +func (fr *FlowRegistry) UpdateShardCount(n uint) error { + fr.mu.Lock() + defer fr.mu.Unlock() + + targetActiveShards := int(n) + + if targetActiveShards == 0 { + return fmt.Errorf("%w: shard count must be a positive integer, but got %d", contracts.ErrInvalidShardCount, n) + } + + currentActiveShards := len(fr.activeShards) + + if targetActiveShards == currentActiveShards { + return nil + } + + if targetActiveShards > currentActiveShards { + return fr.scaleUpLocked(targetActiveShards) + } + return fr.scaleDownLocked(targetActiveShards) +} + +// Stats returns globally aggregated statistics for the entire `FlowRegistry`. +// +// Note: This method is lock-free as it only reads atomic counters and the configuration. +func (fr *FlowRegistry) Stats() contracts.AggregateStats { + // No lock needed here. 
We are reading atomics, a `sync.Map`, and the (effectively immutable) config. + + stats := contracts.AggregateStats{ + TotalCapacityBytes: fr.config.MaxBytes, + TotalByteSize: fr.totalByteSize.Load(), + TotalLen: fr.totalLen.Load(), + PerPriorityBandStats: make(map[uint]contracts.PriorityBandStats, len(fr.config.PriorityBands)), + } + + fr.perPriorityBandStats.Range(func(key, value any) bool { + priority := key.(uint) + bandStats := value.(*bandStats) + bandCfg, err := fr.config.getBandConfig(priority) + if err != nil { + // The stats map was populated from the config, so the config must exist for this priority. + fr.logger.Error(err, "Invariant violation: priority band config missing during stats aggregation", + "priority", priority) + return true + } + + stats.PerPriorityBandStats[priority] = contracts.PriorityBandStats{ + Priority: priority, + PriorityName: bandCfg.PriorityName, + CapacityBytes: bandCfg.MaxBytes, + ByteSize: bandStats.byteSize.Load(), + Len: bandStats.len.Load(), + } + return true + }) + + return stats +} + +// getAllShardsLocked returns a combined slice of active and draining shards. +// Active shards always precede draining shards in the returned slice. +// It expects the registry's lock to be held. +func (fr *FlowRegistry) getAllShardsLocked() []*registryShard { + allShards := make([]*registryShard, 0, len(fr.activeShards)+len(fr.drainingShards)) + allShards = append(allShards, fr.activeShards...) + allShards = append(allShards, fr.drainingShards...) + return allShards +} + +// ShardStats returns a slice of statistics, one for each internal shard (active and draining). +func (fr *FlowRegistry) ShardStats() []contracts.ShardStats { + // To minimize lock contention, we acquire the lock just long enough to get the combined list of shards. + // Iterating and gathering stats (which involves reading shard-level atomics/locks) is done outside the critical + // section. + fr.mu.Lock() + allShards := fr.getAllShardsLocked() + fr.mu.Unlock() + + shardStats := make([]contracts.ShardStats, len(allShards)) + for i, s := range allShards { + shardStats[i] = s.Stats() + } + return shardStats +} + +// Shards returns a slice of accessors for all internal shards (active and draining). +// Active shards always precede draining shards in the returned slice. +func (fr *FlowRegistry) Shards() []contracts.RegistryShard { + // Similar to `ShardStats`, minimize lock contention by getting the list under lock. + fr.mu.Lock() + allShards := fr.getAllShardsLocked() + fr.mu.Unlock() + + // This conversion is necessary because the method signature requires a slice of interfaces. + shardContracts := make([]contracts.RegistryShard, len(allShards)) + for i, s := range allShards { + shardContracts[i] = s + } + return shardContracts +} + +// --- Internal Methods --- + +const shardIDFormat = "shard-%d" + +// scaleUpLocked handles adding new shards to the registry. +// It expects the registry's write lock to be held. +func (fr *FlowRegistry) scaleUpLocked(newTotalActive int) error { + currentActive := len(fr.activeShards) + numToAdd := newTotalActive - currentActive + + fr.logger.Info("Scaling up shards", "currentActive", currentActive, "newTotalActive", newTotalActive) + + // 1. Create the new shards. + newShards := make([]*registryShard, numToAdd) + for i := 0; i < numToAdd; i++ { + // Shard index is based on its position in the final active list (for partitioning). + shardIndex := currentActive + i + + // Generate a unique, stable ID using the monotonic counter. 
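+ // For example, the fourth shard ever created receives the ID "shard-3" even if earlier shards have since been
+ // drained and removed; IDs are never reused within a process lifetime.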
+ shardID := fmt.Sprintf(shardIDFormat, fr.nextShardID) + fr.nextShardID++ + + // Note: The config is partitioned based on the *new* total active count. + partitionedConfig, err := fr.config.partition(shardIndex, newTotalActive) + if err != nil { + return fmt.Errorf("failed to partition config for new shard %s: %w", shardID, err) + } + + callbacks := shardCallbacks{ + propagateStatsDelta: fr.propagateStatsDelta, + signalQueueState: fr.handleQueueStateSignal, + signalShardState: fr.handleShardStateSignal, + } + shard, err := newShard(shardID, partitionedConfig, fr.logger, callbacks) + if err != nil { + return fmt.Errorf("failed to create new shard %s: %w", shardID, err) + } + + // Synchronize all existing flows onto the new shard. + for _, state := range fr.flowStates { + // We only need 1 queue instance for this specific new shard. + policy, queues, err := fr.buildFlowComponents(state.spec, 1) + if err != nil { + // This is unlikely as the flow was already validated, but we handle it defensively. + return fmt.Errorf("failed to prepare synchronization for flow %q on new shard %s: %w", + state.spec.ID, shardID, err) + } + + shard.synchronizeFlow(state.spec, policy, queues[0]) + // Initialize the GC tracking state for the new shard. + state.activeQueueEmptyOnShards[shardID] = true + } + newShards[i] = shard + } + + // 2. Add the new shards to the active list. + fr.activeShards = append(fr.activeShards, newShards...) + + // 3. Re-partition the config for all active shards (old and new). + return fr.repartitionShardConfigsLocked() +} + +// scaleDownLocked handles marking shards for graceful draining and re-partitioning. +// It expects the registry's write lock to be held. +func (fr *FlowRegistry) scaleDownLocked(newTotalActive int) error { + currentActive := len(fr.activeShards) + + fr.logger.Info("Scaling down shards", "currentActive", currentActive, "newTotalActive", newTotalActive) + + // Identify the shards to drain. These are the ones at the end of the active list. + // The slice from index `newTotalActive` to the end contains the shards to move. + shardsToDrain := fr.activeShards[newTotalActive:] + + // Update the active list to only include the remaining active shards. + fr.activeShards = fr.activeShards[:newTotalActive] + + // Move them to the draining list and mark them. + fr.drainingShards = append(fr.drainingShards, shardsToDrain...) + for _, shard := range shardsToDrain { + shard.markAsDraining() + } + + // Re-partition the config across the remaining active shards. + return fr.repartitionShardConfigsLocked() +} + +// applyFlowSynchronizationLocked is the "commit" step of `RegisterOrUpdateFlow`. +// It updates the central `flowState` and propagates the changes to all shards. +// It expects the registry's write lock to be held. +func (fr *FlowRegistry) applyFlowSynchronizationLocked( + spec types.FlowSpecification, + policy framework.IntraFlowDispatchPolicy, + queues []framework.SafeQueue, +) { + flowID := spec.ID + state, exists := fr.flowStates[flowID] + var oldPriority uint + isPriorityChange := false + + // Get the combined list of all shards for state updates and propagation. + allShards := fr.getAllShardsLocked() + + if !exists { + // This is a new flow. + state = newFlowState(spec, allShards) + fr.flowStates[flowID] = state + } else { + // This is an update to an existing flow. 
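+ // If the priority changed, the previously Active instance becomes Draining; the old priority is recorded so its
+ // now-draining queue can be garbage collected immediately below if it is already empty.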
+ if state.spec.Priority != spec.Priority { + isPriorityChange = true + oldPriority = state.spec.Priority + } + state.update(spec, allShards) + } + + // Propagate the update to all shards (active and draining). + // The `queues` slice must align exactly with the `allShards` slice. + if len(allShards) != len(queues) { + // This indicates a severe logic error during the prepare/commit phase synchronization (a race in + // `RegisterOrUpdateFlow`). + panic(fmt.Sprintf( + "invariant violation: shard count (%d) and prepared queue count (%d) mismatch during commit for flow %s", + len(allShards), len(queues), flowID)) + } + + for i, shard := range allShards { + shard.synchronizeFlow(spec, policy, queues[i]) + } + + // If this was a priority change, attempt to GC the newly-draining queue immediately if it's already empty. + if isPriorityChange { + fr.garbageCollectDrainedQueueLocked(flowID, oldPriority) + } + + // Always re-evaluate the GC state after any change. + fr.evaluateFlowGCStateLocked(state) + fr.logger.Info("Successfully registered or updated flow", "flowID", spec.ID, "priority", spec.Priority, + "generation", state.generation) +} + +// repartitionShardConfigsLocked updates the partitioned configuration for all active shards. +// It expects the registry's write lock to be held. +func (fr *FlowRegistry) repartitionShardConfigsLocked() error { + numActive := len(fr.activeShards) + + for i, shard := range fr.activeShards { + newPartitionedConfig, err := fr.config.partition(i, numActive) + if err != nil { + return fmt.Errorf("failed to re-partition config for active shard %s: %w", shard.id, err) + } + shard.updateConfig(newPartitionedConfig) + } + return nil +} + +// buildFlowComponents centralizes the fallible logic of validating a flow spec, fetching its configuration, and +// instantiating its policy and queue. This constitutes the "validation" phase of `RegisterOrUpdateFlow`. +// This function does not require the registry lock. +func (fr *FlowRegistry) buildFlowComponents( + spec types.FlowSpecification, + numQueues int, +) (framework.IntraFlowDispatchPolicy, []framework.SafeQueue, error) { + bandConfig, err := fr.config.getBandConfig(spec.Priority) + if err != nil { + return nil, nil, err + } + + policy, err := intra.NewPolicyFromName(bandConfig.IntraFlowDispatchPolicy) + if err != nil { + return nil, nil, fmt.Errorf("failed to instantiate intra-flow policy %q for flow %q: %w", + bandConfig.IntraFlowDispatchPolicy, spec.ID, err) + } + + // Perform compatibility check. (This check is also done during config validation, but repeated here defensively). + if err := validateBandCompatibility(*bandConfig); err != nil { + return nil, nil, err + } + + // TODO: Optimization: Consider evolving the plugin framework to allow static capability checks (e.g., + // `queue.CapabilitiesByName(name)`) to avoid instantiation here if the only goal was validation. Currently, we must + // instantiate to check capabilities and to get the required comparator. + + queueInstances := make([]framework.SafeQueue, numQueues) + for i := range numQueues { + q, err := queue.NewQueueFromName(bandConfig.Queue, policy.Comparator()) + if err != nil { + // This would be a critical, unexpected error, as we've already validated the plugins exist. 
+ return nil, nil, fmt.Errorf("failed to instantiate queue %q for flow %q: %w", bandConfig.Queue, spec.ID, err) + } + queueInstances[i] = q + } + + return policy, queueInstances, nil +} + +// --- Garbage Collection and State Verification --- +// This section implements the core logic for the "Trust but Verify" garbage collection pattern. + +// garbageCollectFlowLocked attempts to garbage collect an entire flow (e.g., due to inactivity). +// It implements the complete "Trust but Verify" pattern and is the sole authority on flow deletion. +// It expects the registry's write lock to be held by the caller. +func (fr *FlowRegistry) garbageCollectFlowLocked(flowID string) bool { + // We must re-fetch the state here as the GC logic relies on it. + state, exists := fr.flowStates[flowID] + if !exists { + return false // Flow was already deleted, nothing to do. + } + + // --- "Trust but Verify" Step 1: TRUST (the cache) --- + // Check the eventually consistent `flowState`. If the cached state already reflects activity, we can short-circuit + // without the overhead of the expensive live check across all shards. + if !state.isIdle(fr.activeShards) { + // The flowState already reflects activity. This happens if the activity event was processed just before this GC + // attempt. Abort GC. + return false + } + + // --- "Trust but Verify" Step 2: VERIFY (the ground truth) --- + // The cache suggests the flow is idle. We must perform the definitive live check against the atomic counters on the + // live `managedQueue` instances. This prevents race conditions where a flow is incorrectly GC'd just as it becomes + // active (i.e., the GC trigger is processed before the `QueueBecameNonEmpty` signal). + if !fr.isFlowTrulyIdleLocked(flowID) { + // Race resolved: The flow became active just as the GC was triggered. Abort. + fr.logger.V(logging.DEBUG).Info("GC aborted: Live check revealed flow is active", "flowID", flowID) + return false + } + + // --- Step 3: ACT (Destruction) --- + fr.logger.Info("Garbage collecting inactive flow", "flowID", flowID) + + // Collect all priorities associated with this flow (active and draining). + prioritiesToGC := make(map[uint]struct{}) + prioritiesToGC[state.spec.Priority] = struct{}{} + for priority := range state.drainingQueuesEmptyOnShards { + prioritiesToGC[priority] = struct{}{} + } + + // GC all queue instances for this flow from all shards. + for _, shard := range fr.getAllShardsLocked() { + for priority := range prioritiesToGC { + shard.garbageCollect(flowID, priority) + } + } + + delete(fr.flowStates, flowID) + fr.gc.stop(flowID) // Ensure any timer associated with the flow is stopped. + fr.logger.Info("Successfully garbage collected flow", "flowID", flowID) + return true +} + +// garbageCollectDrainedQueueLocked attempts to remove a specific draining queue instance. +// It implements the complete "Trust but Verify" pattern and is the sole authority on draining queue deletion. +// It expects the registry's write lock to be held. +func (fr *FlowRegistry) garbageCollectDrainedQueueLocked(flowID string, priority uint) bool { + // We must re-fetch the state here as the GC logic relies on it. + state, ok := fr.flowStates[flowID] + if !ok { + // Flow might have been GC'd concurrently by the inactivity timer. + return false + } + + // --- "Trust but Verify" Step 1: TRUST (the cache) --- + // Check the eventually consistent `flowState`. 
If the cached state indicates the draining queue is not yet empty + // globally, we can short-circuit without the overhead of the expensive live check. + if !state.isDrained(priority, fr.getAllShardsLocked()) { + // Not all shards have reported completion yet according to the cache. + return false + } + + // --- "Trust but Verify" Step 2: VERIFY (the ground truth) --- + // The cache suggests the queue is drained globally. We must perform the definitive live check against the atomic + // counters. This prevents race conditions where a draining queue is incorrectly GC'd because the control plane's + // cached state is stale (e.g., a priority update occurs before a `QueueBecameNonEmpty` signal for the old priority + // has been processed). + if !fr.isFlowTrulyDrainedLocked(flowID, priority) { + // Race resolved: An item was enqueued to the draining queue just as the GC was triggered. Abort. + fr.logger.V(logging.DEBUG).Info("Draining queue GC aborted: Live check revealed queue is not empty", + "flowID", flowID, "priority", priority) + return false + } + + // --- Step 3: ACT (Destruction) --- + fr.logger.Info("All shards empty for draining queue, triggering garbage collection", "flowID", flowID, + "priority", priority) + + // GC from all shards (active and draining). + for _, shard := range fr.getAllShardsLocked() { + shard.garbageCollect(flowID, priority) + } + // Remove the tracking state for this specific priority. + delete(state.drainingQueuesEmptyOnShards, priority) + return true +} + +// evaluateFlowGCStateLocked is the single source of truth for deciding whether to start or stop a flow's inactivity GC +// timer. +// It expects the registry's write lock to be held. +func (fr *FlowRegistry) evaluateFlowGCStateLocked(state *flowState) { + // GC evaluation only considers the state on active shards. Draining shards do not count towards activity. + if state.isIdle(fr.activeShards) { + fr.logger.V(logging.DEBUG).Info("Flow is now inactive globally, starting GC timer", "flowID", state.spec.ID, + "timeout", fr.config.FlowGCTimeout, "generation", state.generation) + fr.gc.start(state.spec.ID, state.generation, fr.config.FlowGCTimeout) + } else { + // We stop the timer unconditionally if the flow is active. The `gcTracker` handles the case where no timer is + // running. We do not log this as it happens frequently on the hot path. + fr.gc.stop(state.spec.ID) + } +} + +// isFlowTrulyIdleLocked implements the "Verify" step of the "Trust but Verify" pattern for inactivity GC. +// It performs a synchronous live check of the active queue instances across all active shards by reading their atomic +// counters. +// +// Rationale: The centralized `flowState` is eventually consistent. This check provides the definitive ground truth, +// preventing race conditions where a flow is incorrectly GC'd because an activity event hasn't yet been processed by +// the control plane (e.g., the timer event arrives before the `QueueBecameNonEmpty` signal). +// +// It expects the registry's write lock to be held. +func (fr *FlowRegistry) isFlowTrulyIdleLocked(flowID string) bool { + // Iterate only over active shards, as activity on draining shards doesn't prevent inactivity GC. + for _, shard := range fr.activeShards { + // Note: `shard.ActiveManagedQueue` acquires the shard's `RLock` internally. This is safe because the lock hierarchy + // (`Registry.mu` -> `Shard.mu`) is strictly maintained. 
+ mq, err := shard.ActiveManagedQueue(flowID) + if err != nil { + // If the flow exists in `fr.flowStates` (which it must at this point), it MUST have an active queue instance on + // all active shards, guaranteed by `fr.mu` serialization. + // If we cannot find it, the system state is corrupted. + panic(fmt.Sprintf("invariant violation: active flow %s not found on shard %s during live GC check: %v", + flowID, shard.ID(), err)) + } + + // Check the live atomic length. This is the ground truth. + if mq.Len() > 0 { + // Found a request in the active live queue. The flow is definitively NOT idle. + return false + } + } + // All active queues on all active shards are empty. + return true +} + +// isFlowTrulyDrainedLocked implements the "Verify" step of the "Trust but Verify" pattern for draining queue GC. +// It performs a synchronous live check of a draining queue's instances across all shards (active and draining) by +// reading their atomic counters. +// +// Rationale: This mirrors `isFlowTrulyIdleLocked`. It provides the definitive ground truth, preventing race conditions +// where a draining queue is incorrectly GC'd because the control plane's cached state is stale (e.g., a recent enqueue +// to the draining queue hasn't been processed). +// +// It expects the registry's write lock to be held. +func (fr *FlowRegistry) isFlowTrulyDrainedLocked(flowID string, priority uint) bool { + // We must check all shards (active and draining) because the draining queue instance exists on all of them until + // GC'd. + for _, shard := range fr.getAllShardsLocked() { + // Note: `shard.ManagedQueue` acquires the shard's `RLock` internally. This is safe because the lock hierarchy + // (`Registry.mu` -> `Shard.mu`) is strictly maintained. + mq, err := shard.ManagedQueue(flowID, priority) + if err != nil { + // If the flow is being tracked centrally, a queue instance MUST exist on every shard. + // Receiving an error here, especially ErrFlowInstanceNotFound, indicates a severe state inconsistency. + panic(fmt.Sprintf("invariant violation: unexpected error getting queue %s/%d on shard %s during live GC check: %v", + flowID, priority, shard.ID(), err)) + } + + // Check the live atomic length. This is the ground truth. + if mq.Len() > 0 { + // Found an item, so it's not complete. + return false + } + } + // All instances of this draining queue on all shards are empty. + return true +} + +// --- Event Handling (The Control Plane Loop) --- +// These methods are called by the Run loop and expect the registry's write lock (`fr.mu`) to be held. + +// onGCTimerFired handles a garbage collection timer expiration. +func (fr *FlowRegistry) onGCTimerFired(e *gcTimerFiredEvent) { + state, exists := fr.flowStates[e.flowID] + if !exists { + // Flow was already deleted, nothing to do. + return + } + + // If the generation doesn't match, this is a stale timer (e.g., the flow was updated/re-registered). + if state.generation != e.generation { + fr.logger.V(logging.DEBUG).Info("Ignoring stale GC timer event", "flowID", e.flowID, + "eventGeneration", e.generation, "currentGeneration", state.generation) + // Re-evaluate the flow's current idle state. This is crucial because the flow might still be idle but the previous + // timer was invalidated by an update. This ensures a new, correct timer is started if necessary. + fr.evaluateFlowGCStateLocked(state) + return + } + + // The timer is valid for the current generation. Attempt the garbage collection. 
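+ // garbageCollectFlowLocked re-verifies idleness against the live queue counters, so activity that arrived after
+ // this timer fired but before this event was processed still aborts the collection.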
+ fr.garbageCollectFlowLocked(e.flowID) +} + +// onQueueStateChanged handles a state change signal from a `managedQueue`. +func (fr *FlowRegistry) onQueueStateChanged(e *queueStateChangedEvent) { + state, ok := fr.flowStates[e.spec.ID] + if !ok { + // Flow was likely already garbage collected (e.g., by inactivity timer). + return + } + + // Update the centralized tracking state. + state.handleQueueSignal(e.shardID, e.spec.Priority, e.signal) + + if e.signal == queueStateSignalBecameDrained { + // A draining queue instance signaled completion on one shard. Attempt to GC the entire draining queue globally. + fr.garbageCollectDrainedQueueLocked(e.spec.ID, e.spec.Priority) + } else { + // Active flow GC evaluation (BecameEmpty/BecameNonEmpty) only considers active shards. + // `evaluateFlowGCStateLocked` handles the logic of checking only active shards. + fr.evaluateFlowGCStateLocked(state) + } +} + +// onShardStateChanged handles a state change signal from a registryShard. +func (fr *FlowRegistry) onShardStateChanged(e *shardStateChangedEvent) { + if e.signal == shardStateSignalBecameDrained { + fr.logger.Info("Draining shard is now empty, finalizing garbage collection", "shardID", e.shard.id) + + // 1. CRITICAL: Defensively purge the shard's state from all flows first. + // This prevents memory leaks (stale shard IDs remaining in maps) even if the shard removal below fails or if the + // signal was duplicated (which the atomic state transition should prevent, but we are defensive). + for _, flowState := range fr.flowStates { + flowState.purgeShard(e.shard.id) + } + + // 2. Remove the shard from the draining list. + oldLen := len(fr.drainingShards) + fr.drainingShards = slices.DeleteFunc(fr.drainingShards, func(s *registryShard) bool { + return s == e.shard + }) + + // 3. Check for invariant violation. + if len(fr.drainingShards) == oldLen { + // A shard signaled drained but wasn't in the draining list (e.g., it was active, or the signal was somehow + // processed twice despite the atomic latch). + // The system state is potentially corrupted. + panic(fmt.Sprintf("invariant violation: shard %s not found in draining list during garbage collection", + e.shard.id)) + } + } +} + +// --- Callbacks (Data Plane -> Control Plane Communication) --- + +// handleQueueStateSignal is the callback passed to shards to allow them to signal queue state changes. +// It sends an event to the event channel for serialized processing by the control plane. +func (fr *FlowRegistry) handleQueueStateSignal(shardID string, spec types.FlowSpecification, signal queueStateSignal) { + // This must block if the channel is full. Dropping events would cause state divergence and memory leaks, as the GC + // system is edge-triggered. Blocking provides necessary backpressure from the data plane to the control plane. + fr.events <- &queueStateChangedEvent{ + shardID: shardID, + spec: spec, + signal: signal, + } +} + +// handleShardStateSignal is the callback passed to shards to allow them to signal their own state changes. +func (fr *FlowRegistry) handleShardStateSignal(shard *registryShard, signal shardStateSignal) { + // This must also block (see `handleQueueStateSignal`). + fr.events <- &shardStateChangedEvent{ + shard: shard, + signal: signal, + } +} + +// propagateStatsDelta is the callback passed to shards to allow them to propagate their statistics changes up to the +// global registry. This is lock-free. 
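+//
+// Illustrative sketch of the wraparound on a fresh counter (not part of this file):
+//
+//	var total atomic.Uint64
+//	total.Add(5)
+//	total.Add(uint64(int64(-2))) // two's complement wrap: total.Load() == 3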
+func (fr *FlowRegistry) propagateStatsDelta(priority uint, lenDelta, byteSizeDelta int64) { + // This function uses two's complement arithmetic to atomically add or subtract from the unsigned counters. + // Casting a negative `int64` to `uint64` results in its two's complement representation, which, when added, is + // equivalent to subtraction. This is a standard and efficient pattern for atomic updates on unsigned integers. + fr.totalLen.Add(uint64(lenDelta)) + fr.totalByteSize.Add(uint64(byteSizeDelta)) + + if bandVal, ok := fr.perPriorityBandStats.Load(priority); ok { + band := bandVal.(*bandStats) + band.len.Add(uint64(lenDelta)) + band.byteSize.Add(uint64(byteSizeDelta)) + } else { + // Stats are being propagated for a priority that wasn't initialized. + panic(fmt.Sprintf("invariant violation: priority band (%d) stats missing during propagation", priority)) + } +} diff --git a/pkg/epp/flowcontrol/registry/registry_test.go b/pkg/epp/flowcontrol/registry/registry_test.go new file mode 100644 index 000000000..4fd87b8b5 --- /dev/null +++ b/pkg/epp/flowcontrol/registry/registry_test.go @@ -0,0 +1,1193 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package registry + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/go-logr/logr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "k8s.io/utils/clock" + testclock "k8s.io/utils/clock/testing" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework" + frameworkmocks "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/mocks" + intra "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/policies/intraflow/dispatch" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types/mocks" +) + +const ( + // Define standard priorities for testing. + pHigh = uint(10) + pMed = uint(20) + pLow = uint(30) + + // Define timeout for synchronization operations to prevent tests from hanging indefinitely. + syncTimeout = 2 * time.Second +) + +// -- Test Fixture --- + +// registryTestFixture provides a harness for testing the `FlowRegistry`. +type registryTestFixture struct { + t *testing.T + fr *FlowRegistry + ctx context.Context + cancel context.CancelFunc + config *Config + // fakeClock is non-nil if `useFakeClock` is true, allowing deterministic time control. + fakeClock *testclock.FakeClock +} + +// fixtureOptions allows customization of the test fixture. +type fixtureOptions struct { + initialShardCount uint + customConfig *Config + // useFakeClock determines if a `FakeClock` is used (for GC tests) or `RealClock` (for stress/blocking tests). + useFakeClock bool +} + +// newRegistryTestFixture creates and starts a new `FlowRegistry` for testing. 
+func newRegistryTestFixture(t *testing.T, opts fixtureOptions) *registryTestFixture { + t.Helper() + + if opts.initialShardCount == 0 { + opts.initialShardCount = 1 // Default to 1 shard if not specified. + } + + config := opts.customConfig.deepCopy() + if config == nil { + // Default configuration if none provided. + config = &Config{ + // Use a specific timeout; we control the clock or use real time depending on the test. + FlowGCTimeout: 1 * time.Minute, + // Ensure a reasonable buffer size for most tests, unless overridden. + EventChannelBufferSize: 100, + PriorityBands: []PriorityBandConfig{ + {Priority: pHigh, PriorityName: "High"}, + {Priority: pMed, PriorityName: "Medium"}, + {Priority: pLow, PriorityName: "Low"}, + }, + } + } + + var clk clock.WithTickerAndDelayedExecution + var fakeClock *testclock.FakeClock + + if opts.useFakeClock { + startTime := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + fakeClock = testclock.NewFakeClock(startTime) + clk = fakeClock + } else { + clk = &clock.RealClock{} + } + + // Initialize the registry. + fr, err := NewFlowRegistry(config, opts.initialShardCount, logr.Discard(), WithClock(clk)) + require.NoError(t, err, "NewFlowRegistry should not fail") + + ctx, cancel := context.WithCancel(context.Background()) + // Use a `WaitGroup` to ensure the registry's `Run` loop stops cleanly. + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + fr.Run(ctx) + }() + + t.Cleanup(func() { + cancel() + wg.Wait() + }) + + return ®istryTestFixture{ + t: t, + fr: fr, + ctx: ctx, + cancel: cancel, + config: config, + fakeClock: fakeClock, + } +} + +// synchronizeControlPlane blocks until the `FlowRegistry`'s event loop has processed all events currently in the queue. +// This provides a deterministic synchronization point for tests. +func (f *registryTestFixture) synchronizeControlPlane() { + f.t.Helper() + // Use a short timeout context for the synchronization itself. + ctx, cancel := context.WithTimeout(f.ctx, syncTimeout) + defer cancel() + + doneCh := make(chan struct{}) + evt := &syncEvent{doneCh: doneCh} // syncEvent is defined in `lifecycle.go`. + + // Send the sync event. This might block if the event channel is full. + select { + case f.fr.events <- evt: + // Event sent, now wait for the acknowledgment from the event loop. + select { + case <-doneCh: + // Synchronized successfully. + return + case <-ctx.Done(): + f.t.Fatalf("Timeout or cancellation waiting for FlowRegistry synchronization: %v", ctx.Err()) + } + case <-ctx.Done(): + // This likely means the event channel is full and the context timed out before we could send the event. + f.t.Fatalf("Timeout or cancellation sending sync event (channel likely full): %v", ctx.Err()) + } +} + +// advanceTime moves the fake clock forward. It does NOT synchronize. +func (f *registryTestFixture) advanceTime(d time.Duration) { + f.t.Helper() + require.NotNil(f.t, f.fakeClock, "advanceTime requires the fixture to be initialized with useFakeClock=true") + // Step triggers the timer callbacks (e.g., `AfterFunc` in `gcTracker`), which send events to the channel. + f.fakeClock.Step(d) +} + +// advanceClockToGCTimeout is a convenience helper to step the clock past the configured GC timeout. +func (f *registryTestFixture) advanceClockToGCTimeout() { + f.t.Helper() + // Step slightly past the timeout to ensure the timer fires. + f.advanceTime(f.config.FlowGCTimeout + time.Millisecond) +} + +// --- Fixture Helpers (Shard Access and Manipulation) --- + +// getShardByID retrieves a specific shard by its ID. 
+func (f *registryTestFixture) getShardByID(id string) contracts.RegistryShard { + f.t.Helper() + shards := f.fr.Shards() + for _, shard := range shards { + if shard.ID() == id { + return shard + } + } + f.t.Fatalf("Shard with ID %s not found", id) + return nil +} + +// getShardByIndex retrieves a specific shard by index. +// NOTE: This relies on the deterministic ordering of `Shards()`, which is guaranteed by the implementation and +// necessary for testing partitioning logic, but should be avoided otherwise. +func (f *registryTestFixture) getShardByIndex(index int) contracts.RegistryShard { + f.t.Helper() + shards := f.fr.Shards() + require.Less(f.t, index, len(shards), "Shard index %d out of bounds (Total Shards: %d)", index, len(shards)) + return shards[index] +} + +// addItem adds an item to a specific flow on a specific shard ID. +func (f *registryTestFixture) addItem( + flowID string, + priority uint, + shardID string, + size uint64, +) types.QueueItemAccessor { + f.t.Helper() + shard := f.getShardByID(shardID) + mq, err := shard.ManagedQueue(flowID, priority) + require.NoError(f.t, err, "Failed to get queue for flow %s on shard %s", flowID, shardID) + + item := mocks.NewMockQueueItemAccessor(size, fmt.Sprintf("req-%s-%s", flowID, shardID), flowID) + // Note: Add might fail if the specific queue is draining, which is expected behavior in some tests. + err = mq.Add(item) + // We check the shard status in the error message for easier debugging if the Add fails unexpectedly. + require.NoError(f.t, err, "Failed to add item to queue on shard %s.", shardID) + return item +} + +// removeItem removes an item from a specific flow and priority on a specific shard ID. +func (f *registryTestFixture) removeItem(flowID string, priority uint, shardID string, item types.QueueItemAccessor) { + f.t.Helper() + internalShard := f.getShardByID(shardID) + + mq, err := internalShard.ManagedQueue(flowID, priority) + require.NoError(f.t, err, "Failed to get queue for flow %s at priority %d on shard %s", flowID, priority, shardID) + + _, err = mq.Remove(item.Handle()) + require.NoError(f.t, err, "Failed to remove item from queue on shard %s", shardID) +} + +// getActiveShardIDs returns the IDs of all currently active shards in deterministic order. +func (f *registryTestFixture) getActiveShardIDs() []string { + f.t.Helper() + var ids []string + // `Shards()` returns shards in a deterministic order (active first, then draining). + for _, shard := range f.fr.Shards() { + if shard.IsActive() { + ids = append(ids, shard.ID()) + } + } + return ids +} + +// --- Assertion Helpers (Synchronous) --- +// These helpers check the current state. They assume the caller has synchronized the control plane if necessary. + +// assertFlowExistsNow verifies that a flow exists and is active at the expected priority across all active shards. 
+func (f *registryTestFixture) assertFlowExistsNow(flowID string, expectedPriority uint) { + f.t.Helper() + shards := f.fr.Shards() + require.NotEmpty(f.t, shards, "Registry should have shards") + + activeShardCount := 0 + for _, shard := range shards { + if !shard.IsActive() { + continue + } + activeShardCount++ + + mq, err := shard.ActiveManagedQueue(flowID) + if assert.NoError(f.t, err, "Flow %s not active on active shard %s", flowID, shard.ID()) { + assert.Equal(f.t, expectedPriority, mq.FlowQueueAccessor().FlowSpec().Priority, + "Flow %s active at wrong priority on shard %s", flowID, shard.ID()) + } + } + assert.Positive(f.t, activeShardCount, "Registry should have active shards") +} + +// assertFlowDoesNotExistNow verifies that a flow is garbage collected and removed from all shards (active and +// draining). +func (f *registryTestFixture) assertFlowDoesNotExistNow(flowID string) { + f.t.Helper() + shards := f.fr.Shards() + for _, shard := range shards { + // Check both active and draining shards. + _, err := shard.ActiveManagedQueue(flowID) + if assert.Error(f.t, err, "Flow %s should not have an active queue on shard %s", flowID, shard.ID()) { + // Ensure the error is the expected "not found" error. + assert.ErrorIs(f.t, err, contracts.ErrFlowInstanceNotFound, + "Unexpected error when checking for flow existence on shard %s", shard.ID()) + } + } + + // Grey-box check: Ensure it's truly gone from the central tracking map. + f.fr.mu.Lock() + _, exists := f.fr.flowStates[flowID] + f.fr.mu.Unlock() + assert.False(f.t, exists, "Flow %s should not exist in internal flowStates map", flowID) +} + +// assertQueueIsDrainingNow verifies that a specific queue instance exists across all shards and exhibits draining +// behavior. +func (f *registryTestFixture) assertQueueIsDrainingNow(flowID string, priority uint) { + f.t.Helper() + shards := f.fr.Shards() + require.NotEmpty(f.t, shards, "Registry has no shards") + + for _, shard := range shards { + // 1. Verify the queue instance exists at the expected priority. + mq, err := shard.ManagedQueue(flowID, priority) + require.NoError(f.t, err, "Draining queue for flow %s at priority %d not found on shard %s", + flowID, priority, shard.ID()) + + // 2. Verify draining behavior: `Add` should fail. + item := mocks.NewMockQueueItemAccessor(1, "test-drain", flowID) + err = mq.Add(item) + require.Error(f.t, err, "Add to a draining queue should fail on shard %s", shard.ID()) + assert.ErrorIs(f.t, err, contracts.ErrFlowInstanceNotFound, + "Error type mismatch when adding to draining queue on shard %s", shard.ID()) + } +} + +// assertStatsNow verifies that the global registry statistics match the expected values. +func (f *registryTestFixture) assertStatsNow(expectedLen, expectedBytes uint64) { + f.t.Helper() + stats := f.fr.Stats() + assert.Equal(f.t, expectedLen, stats.TotalLen, "Global TotalLen mismatch") + assert.Equal(f.t, expectedBytes, stats.TotalByteSize, "Global TotalByteSize mismatch") +} + +// --- Test Functions: Initialization and Validation --- + +func TestFlowRegistry_InitializationErrors(t *testing.T) { + t.Parallel() + + t.Run("Invalid configuration (no bands)", func(t *testing.T) { + t.Parallel() + invalidConfig := &Config{} // No priority bands is invalid. 
+ _, err := NewFlowRegistry(invalidConfig, 1, logr.Discard()) + assert.Error(t, err, "NewFlowRegistry should fail with an invalid config") + assert.Contains(t, err.Error(), "master configuration is invalid") + }) + + t.Run("Initial shard count is 0", func(t *testing.T) { + t.Parallel() + validConfig := &Config{ + PriorityBands: []PriorityBandConfig{{Priority: 1, PriorityName: "P1"}}, + } + // `UpdateShardCount` is called internally and rejects 0 shards. + _, err := NewFlowRegistry(validConfig, 0, logr.Discard()) + assert.Error(t, err, "NewFlowRegistry should fail if initial shard count is 0") + assert.Contains(t, err.Error(), "failed to initialize shards") + // Check that the specific error from UpdateShardCount bubbles up. + // Note: This relies on the implementation detail of the error wrapping structure. + // If the wrapping changes, this specific check might need adjustment. + assert.ErrorIs(t, err, contracts.ErrInvalidShardCount, "error type mismatch") + }) +} + +// --- Test Functions: Registration and Updates --- + +func TestFlowRegistry_RegisterNewFlow(t *testing.T) { + t.Parallel() + f := newRegistryTestFixture(t, fixtureOptions{initialShardCount: 2, useFakeClock: true}) + const flowID = "test-flow-new" + specHigh := types.FlowSpecification{ID: flowID, Priority: pHigh} + + err := f.fr.RegisterOrUpdateFlow(specHigh) + require.NoError(t, err, "Registering a new flow should succeed") + + f.synchronizeControlPlane() + f.assertFlowExistsNow(flowID, pHigh) +} + +func TestFlowRegistry_RegisterInvalidFlow(t *testing.T) { + t.Parallel() + f := newRegistryTestFixture(t, fixtureOptions{initialShardCount: 1, useFakeClock: true}) + + t.Run("Empty ID", func(t *testing.T) { + err := f.fr.RegisterOrUpdateFlow(types.FlowSpecification{ID: "", Priority: pHigh}) + assert.ErrorIs(t, err, contracts.ErrFlowIDEmpty, "Should reject empty FlowID") + }) + + t.Run("Unknown Priority", func(t *testing.T) { + err := f.fr.RegisterOrUpdateFlow(types.FlowSpecification{ID: "bad-flow", Priority: 999}) + assert.ErrorIs(t, err, contracts.ErrPriorityBandNotFound, "Should reject unknown priority") + }) +} + +func TestFlowRegistry_Update_NoOp(t *testing.T) { + t.Parallel() + f := newRegistryTestFixture(t, fixtureOptions{initialShardCount: 2, useFakeClock: true}) + const flowID = "test-flow-noop" + specHigh := types.FlowSpecification{ID: flowID, Priority: pHigh} + + // 1. Initial registration + require.NoError(t, f.fr.RegisterOrUpdateFlow(specHigh)) + f.synchronizeControlPlane() + + // Capture initial queue instances for comparison later. + initialQueues := make(map[string]contracts.ManagedQueue) + for _, shard := range f.fr.Shards() { + mq, err := shard.ActiveManagedQueue(flowID) + require.NoError(t, err, "Active queue not found on shard %s before no-op update", shard.ID()) + initialQueues[shard.ID()] = mq + } + + // 2. No-op update + err := f.fr.RegisterOrUpdateFlow(specHigh) + require.NoError(t, err, "No-op update should succeed") + + // 3. Verification + f.synchronizeControlPlane() + f.assertFlowExistsNow(flowID, pHigh) + for _, shard := range f.fr.Shards() { + mq, err := shard.ActiveManagedQueue(flowID) + require.NoError(t, err, "Active queue not found on shard %s after no-op update", shard.ID()) + // Crucial check: Ensure the object reference hasn't changed (optimization). 
+ assert.Same(t, initialQueues[shard.ID()], mq, "Queue instance should not change on no-op update for shard %s", + shard.ID()) + } +} + +func TestFlowRegistry_Update_PriorityChange(t *testing.T) { + t.Parallel() + f := newRegistryTestFixture(t, fixtureOptions{initialShardCount: 2, useFakeClock: true}) + const flowID = "test-flow-prio-change" + specHigh := types.FlowSpecification{ID: flowID, Priority: pHigh} + specMed := types.FlowSpecification{ID: flowID, Priority: pMed} + + // 1. Setup: Register at `High` priority and add an item. + require.NoError(t, f.fr.RegisterOrUpdateFlow(specHigh)) + shardID := f.getActiveShardIDs()[0] + f.addItem(flowID, pHigh, shardID, 100) + f.synchronizeControlPlane() + f.assertStatsNow(1, 100) + + // 2. Action: Update priority (`High` -> `Med`). + err := f.fr.RegisterOrUpdateFlow(specMed) + require.NoError(t, err, "Priority update should succeed") + + // 3. Verification: `Med` is active, `High` is draining. + f.synchronizeControlPlane() + f.assertFlowExistsNow(flowID, pMed) + f.assertQueueIsDrainingNow(flowID, pHigh) + // Stats should be unchanged as the item is still in the draining queue. + f.assertStatsNow(1, 100) +} + +func TestFlowRegistry_Update_PriorityChangeReactivation(t *testing.T) { + t.Parallel() + f := newRegistryTestFixture(t, fixtureOptions{initialShardCount: 1, useFakeClock: true}) + const flowID = "test-flow-reactivation" + specHigh := types.FlowSpecification{ID: flowID, Priority: pHigh} + specMed := types.FlowSpecification{ID: flowID, Priority: pMed} + shardID := f.getActiveShardIDs()[0] // We only have one shard in this test. + + // 1. Setup: Register `High`, add item, switch to `Med`. + require.NoError(t, f.fr.RegisterOrUpdateFlow(specHigh)) + f.addItem(flowID, pHigh, shardID, 100) + require.NoError(t, f.fr.RegisterOrUpdateFlow(specMed), "Priority update to Med should succeed") + f.synchronizeControlPlane() + + // Verify intermediate state: `High` is draining, `Med` is active. + f.assertQueueIsDrainingNow(flowID, pHigh) + // Capture the draining queue instance for later comparison. + drainingQueueAtHigh, err := f.getShardByID(shardID).ManagedQueue(flowID, pHigh) + require.NoError(t, err, "Failed to get draining queue at High priority") + + // Add item to the currently active queue (`Med`). + f.addItem(flowID, pMed, shardID, 50) + f.synchronizeControlPlane() + f.assertStatsNow(2, 150) + + // 2. Action: Switch back to `High` (`Med` -> `High`). + err = f.fr.RegisterOrUpdateFlow(specHigh) + require.NoError(t, err, "Priority rollback should succeed") + + // 3. Verification: `High` is active, `Med` is draining. + f.synchronizeControlPlane() + f.assertFlowExistsNow(flowID, pHigh) + f.assertQueueIsDrainingNow(flowID, pMed) + + // Crucial check: The queue instance at High priority should be the SAME object that was previously draining. + activeQueueAtHigh, err := f.getShardByID(shardID).ActiveManagedQueue(flowID) + require.NoError(t, err, "Failed to get active queue at High priority after reactivation") + assert.Same(t, drainingQueueAtHigh, activeQueueAtHigh, + "The previously draining queue instance should be reactivated (optimization check)") +} + +// TestFlowRegistry_RegistrationAtomicityOnFailure verifies that a failed registration (e.g., due to plugin failure) +// does not leave the registry in an inconsistent state. +func TestFlowRegistry_RegistrationAtomicityOnFailure(t *testing.T) { + const pFail = uint(99) + const flowID = "atomic-flow" + policyName := intra.RegisteredPolicyName("mutable-policy-for-atomicity-test") + + // 1. 
Create a mock policy that is initially valid. + mockPolicy := &frameworkmocks.MockIntraFlowDispatchPolicy{ + RequiredQueueCapabilitiesV: []framework.QueueCapability{}, // Initially, no requirements. + } + + // 2. Register the mock policy via a closure. + intra.MustRegisterPolicy(policyName, func() (framework.IntraFlowDispatchPolicy, error) { + return mockPolicy, nil + }) + + // 3. Start with a valid configuration that uses the initially valid mock policy. + config := &Config{ + PriorityBands: []PriorityBandConfig{ + {Priority: pHigh, PriorityName: "High"}, + {Priority: pFail, PriorityName: "Mutable-Policy-Band", IntraFlowDispatchPolicy: policyName}, + }, + } + f := newRegistryTestFixture(t, fixtureOptions{initialShardCount: 1, customConfig: config, useFakeClock: true}) + + // 4. AFTER initialization, mutate the mock policy to make it invalid. + mockPolicy.RequiredQueueCapabilitiesV = []framework.QueueCapability{"impossible-capability"} + + shardID := f.getActiveShardIDs()[0] + + // 5. Setup: Register a flow at a valid priority. + specHigh := types.FlowSpecification{ID: flowID, Priority: pHigh} + require.NoError(t, f.fr.RegisterOrUpdateFlow(specHigh), "Initial valid registration should succeed") + f.synchronizeControlPlane() + f.assertFlowExistsNow(flowID, pHigh) + + // 6. Action: Attempt to update the flow to the now-failing priority. + // This will fail during `prepareFlowSynchronization` due to the impossible capability requirement. + specFail := types.FlowSpecification{ID: flowID, Priority: pFail} + err := f.fr.RegisterOrUpdateFlow(specFail) + + // 7. Verification: Ensure failure occurred and the state remains unchanged. + require.Error(t, err, "Registration should fail when policy is incompatible") + assert.ErrorIs(t, err, contracts.ErrPolicyQueueIncompatible, "Error should be due to incompatibility") + + f.synchronizeControlPlane() + // The flow should still be active at the original priority (`pHigh`). + f.assertFlowExistsNow(flowID, pHigh) + + // No artifacts should exist at the failed priority (`pFail`). + _, err = f.getShardByID(shardID).ManagedQueue(flowID, pFail) + assert.ErrorIs(t, err, contracts.ErrFlowInstanceNotFound, + "No queue instance should exist at the failed priority (pFail)") +} + +// --- Test Functions: Garbage Collection --- + +func TestFlowRegistry_GC_IdleFlow(t *testing.T) { + t.Parallel() + f := newRegistryTestFixture(t, fixtureOptions{initialShardCount: 2, useFakeClock: true}) + const flowID = "gc-idle" + specHigh := types.FlowSpecification{ID: flowID, Priority: pHigh} + + // 1. Register an idle flow (starts GC timer). + require.NoError(t, f.fr.RegisterOrUpdateFlow(specHigh), "Registering flow should succeed") + f.synchronizeControlPlane() + f.assertFlowExistsNow(flowID, pHigh) + + // 2. Advance time past the timeout. + f.advanceClockToGCTimeout() + + // 3. Verify GC occurred. + f.synchronizeControlPlane() + f.assertFlowDoesNotExistNow(flowID) +} + +func TestFlowRegistry_GC_ActiveFlowNotCollected(t *testing.T) { + t.Parallel() + f := newRegistryTestFixture(t, fixtureOptions{initialShardCount: 2, useFakeClock: true}) + const flowID = "gc-active" + specHigh := types.FlowSpecification{ID: flowID, Priority: pHigh} + + // 1. Register flow (starts GC timer). + require.NoError(t, f.fr.RegisterOrUpdateFlow(specHigh), "Registering flow should succeed") + f.synchronizeControlPlane() + + // 2. Advance time partially. + f.advanceTime(f.config.FlowGCTimeout / 2) + + // 3. Make the flow active (stops GC timer). 
+ shardID := f.getActiveShardIDs()[0] + f.addItem(flowID, pHigh, shardID, 100) + f.synchronizeControlPlane() + f.assertStatsNow(1, 100) + + // 4. Advance time past the original timeout. + f.advanceClockToGCTimeout() + + // 5. Verify flow still exists. + f.synchronizeControlPlane() + f.assertFlowExistsNow(flowID, pHigh) +} + +func TestFlowRegistry_GC_TimerRestartsWhenIdleAgain(t *testing.T) { + t.Parallel() + f := newRegistryTestFixture(t, fixtureOptions{initialShardCount: 1, useFakeClock: true}) + const flowID = "gc-restart" + specHigh := types.FlowSpecification{ID: flowID, Priority: pHigh} + shardID := f.getActiveShardIDs()[0] // We only have one shard in this test. + + // 1. Setup: Active flow. + require.NoError(t, f.fr.RegisterOrUpdateFlow(specHigh), "Registering flow should succeed") + item := f.addItem(flowID, pHigh, shardID, 100) + f.synchronizeControlPlane() + + // 2. Make flow idle (restarts GC timer). + f.removeItem(flowID, pHigh, shardID, item) + f.synchronizeControlPlane() + f.assertStatsNow(0, 0) + + // 3. Advance time partially and verify it still exists. + f.advanceTime(f.config.FlowGCTimeout / 2) + f.synchronizeControlPlane() + f.assertFlowExistsNow(flowID, pHigh) + + // 4. Advance time past timeout and verify GC. + f.advanceClockToGCTimeout() + f.synchronizeControlPlane() + f.assertFlowDoesNotExistNow(flowID) +} + +// TestFlowRegistry_GC_GenerationHandlesStaleTimers verifies that updates (which increment generation) correctly +// invalidate previous timers and start new ones, even if the flow remains idle. +func TestFlowRegistry_GC_GenerationHandlesStaleTimers(t *testing.T) { + t.Parallel() + f := newRegistryTestFixture(t, fixtureOptions{initialShardCount: 1, useFakeClock: true}) + const flowID = "gc-generation" + specHigh := types.FlowSpecification{ID: flowID, Priority: pHigh} + + // 1. Register flow (Gen 1 timer starts). + require.NoError(t, f.fr.RegisterOrUpdateFlow(specHigh), "Registering flow should succeed") + f.synchronizeControlPlane() + + // 2. Advance time partially. + f.advanceTime(f.config.FlowGCTimeout / 2) + + // 3. Update flow (Gen 1 timer stopped, Gen 2 timer starts). + // A no-op update still increments the generation and resets the timer. + require.NoError(t, f.fr.RegisterOrUpdateFlow(specHigh), "No-op update should succeed") + f.synchronizeControlPlane() + + // 4. Advance time past the *original* timeout (Gen 1 timer would fire here if not stopped). + f.advanceTime(f.config.FlowGCTimeout/2 + time.Second*2) + f.synchronizeControlPlane() + + // 5. Verify flow still exists because the Gen 2 timer is still running. + f.assertFlowExistsNow(flowID, pHigh) + + // 6. Advance time past the *new* timeout (Gen 2 timer fires). + f.advanceClockToGCTimeout() + f.synchronizeControlPlane() + f.assertFlowDoesNotExistNow(flowID) +} + +// TestFlowRegistry_GCRace_TimerFiresBeforeActivityProcessed tests the critical race condition where a flow becomes +// active just as its GC timer fires. The timer event might be processed before the activity event. +func TestFlowRegistry_GCRace_TimerFiresBeforeActivityProcessed(t *testing.T) { + t.Parallel() + f := newRegistryTestFixture(t, fixtureOptions{initialShardCount: 1, useFakeClock: true}) + const flowID = "gc-race-activity" + spec := types.FlowSpecification{ID: flowID, Priority: pHigh} + shardID := f.getActiveShardIDs()[0] // We only have one shard in this test. + + // 1. Register flow (starts GC timer). + require.NoError(t, f.fr.RegisterOrUpdateFlow(spec), "Registering flow should succeed") + f.synchronizeControlPlane() + + // 2. 
Advance clock. This fires the timer and queues the `gcTimerFiredEvent`. + // The event is now sitting in the channel, waiting to be processed. + f.advanceClockToGCTimeout() + + // 3. Add an item. This queues the `queueStateSignalBecameNonEmpty` event BEHIND the timer event. + f.addItem(flowID, pHigh, shardID, 100) + + // 4. Process events. The control loop processes the timer event first. + // The implementation of `onGCTimerFired` must correctly identify that the flow is no longer idle. + f.synchronizeControlPlane() + + // 5. Verify the flow was NOT GC'd and stats are correct. + f.assertFlowExistsNow(flowID, pHigh) + f.assertStatsNow(1, 100) +} + +// TestFlowRegistry_GC_DrainingQueue verifies that a draining queue (due to priority change) is only +// garbage collected when it is empty across ALL shards. +func TestFlowRegistry_GC_DrainingQueue(t *testing.T) { + t.Parallel() + f := newRegistryTestFixture(t, fixtureOptions{initialShardCount: 2, useFakeClock: true}) + const flowID = "gc-draining" + specHigh := types.FlowSpecification{ID: flowID, Priority: pHigh} + specMed := types.FlowSpecification{ID: flowID, Priority: pMed} + + shardIDs := f.getActiveShardIDs() + require.Len(t, shardIDs, 2) + shard0ID, shard1ID := shardIDs[0], shardIDs[1] + + // 1. Setup: Register `High`, add items to both shards. + require.NoError(t, f.fr.RegisterOrUpdateFlow(specHigh)) + item0 := f.addItem(flowID, pHigh, shard0ID, 100) + item1 := f.addItem(flowID, pHigh, shard1ID, 50) + f.synchronizeControlPlane() + f.assertStatsNow(2, 150) + + // 2. Action: Change priority (`High` -> `Med`). `High` is now draining. + require.NoError(t, f.fr.RegisterOrUpdateFlow(specMed), "Priority update to Med should succeed") + f.synchronizeControlPlane() + f.assertQueueIsDrainingNow(flowID, pHigh) + + // 3. Drain Shard 0. + f.removeItem(flowID, pHigh, shard0ID, item0) + f.synchronizeControlPlane() + f.assertStatsNow(1, 50) + + // 4. Verify: Queue `P_High` should still exist globally because Shard 1 is not empty. + _, err := f.getShardByID(shard0ID).ManagedQueue(flowID, pHigh) + assert.NoError(t, err, "Draining queue should still exist on shard 0 (even if empty locally)") + _, err = f.getShardByID(shard1ID).ManagedQueue(flowID, pHigh) + assert.NoError(t, err, "Draining queue should still exist on shard 1") + + // 5. Drain Shard 1. + f.removeItem(flowID, pHigh, shard1ID, item1) + f.synchronizeControlPlane() + f.assertStatsNow(0, 0) + + // 6. Verify: Queue `P_High` should be garbage collected globally. + for _, shardID := range shardIDs { + _, err := f.getShardByID(shardID).ManagedQueue(flowID, pHigh) + assert.ErrorIs(t, err, contracts.ErrFlowInstanceNotFound, + "Draining queue at pHigh was not garbage collected on shard %s", shardID) + } + + // `P_Med` should still be active. + f.assertFlowExistsNow(flowID, pMed) +} + +// --- Test Functions: Sharding and Scaling --- + +// Helper config for sharding tests with defined capacities. +func newShardingTestConfig() *Config { + return &Config{ + MaxBytes: 100, + FlowGCTimeout: 1 * time.Minute, + PriorityBands: []PriorityBandConfig{ + {Priority: pHigh, PriorityName: "High", MaxBytes: 50}, + }, + } +} + +func TestFlowRegistry_Sharding_InitializationPartitioning(t *testing.T) { + t.Parallel() + // Total 100, Band 50. 3 Shards. + // Global: 100/3 = 33, Rem 1. Shards get 34, 33, 33. + // Band: 50/3 = 16, Rem 2. Shards get 17, 17, 16. 
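// [Editorial aside, not part of this diff] The expected capacities in the comment above are
// consistent with a largest-remainder partitioning of the configured limits across active shards.
// A hypothetical helper illustrating that arithmetic (the actual partitioning code is not shown in
// this change):
//
//	func partition(total uint64, shardCount, shardIndex int) uint64 {
//		base := total / uint64(shardCount)
//		if uint64(shardIndex) < total%uint64(shardCount) {
//			return base + 1 // the first (total % shardCount) shards absorb the remainder
//		}
//		return base
//	}
//
//	// partition(100, 3, i) -> 34, 33, 33 and partition(50, 3, i) -> 17, 17, 16 for i = 0, 1, 2.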
+ opts := fixtureOptions{ + customConfig: newShardingTestConfig(), + initialShardCount: 3, + useFakeClock: true, + } + f := newRegistryTestFixture(t, opts) + require.Len(t, f.fr.Shards(), 3) + + // We rely on the deterministic ordering returned by `Shards()` for partitioning validation. + s0 := f.getShardByIndex(0).Stats() + assert.Equal(t, uint64(34), s0.TotalCapacityBytes, "Shard 0 Global Capacity") + assert.Equal(t, uint64(17), s0.PerPriorityBandStats[pHigh].CapacityBytes, "Shard 0 Band Capacity") + + s1 := f.getShardByIndex(1).Stats() + assert.Equal(t, uint64(33), s1.TotalCapacityBytes, "Shard 1 Global Capacity") + assert.Equal(t, uint64(17), s1.PerPriorityBandStats[pHigh].CapacityBytes, "Shard 1 Band Capacity") + + s2 := f.getShardByIndex(2).Stats() + assert.Equal(t, uint64(33), s2.TotalCapacityBytes, "Shard 2 Global Capacity") + assert.Equal(t, uint64(16), s2.PerPriorityBandStats[pHigh].CapacityBytes, "Shard 2 Band Capacity") +} + +func TestFlowRegistry_Sharding_ScaleUp(t *testing.T) { + t.Parallel() + // Start with 1, scale to 3. + opts := fixtureOptions{ + customConfig: newShardingTestConfig(), + initialShardCount: 1, + useFakeClock: true, + } + f := newRegistryTestFixture(t, opts) + const flowID = "scale-up-flow" + specHigh := types.FlowSpecification{ID: flowID, Priority: pHigh} + + // Register a flow before scaling to ensure it propagates to new shards. + require.NoError(t, f.fr.RegisterOrUpdateFlow(specHigh)) + + // Action: Scale Up (1 -> 3). + err := f.fr.UpdateShardCount(3) + require.NoError(t, err, "Scaling up should succeed") + f.synchronizeControlPlane() + + // Verification. + shards := f.fr.Shards() + require.Len(t, shards, 3) + + // Check activity status and partitioning (re-partitioned to 34, 33, 33). + for i, shard := range shards { + assert.True(t, shard.IsActive(), "Shard %d (%s) should be active after scale up", i, shard.ID()) + } + assert.Equal(t, uint64(34), f.getShardByIndex(0).Stats().TotalCapacityBytes, "Shard 0 Global Capacity") + assert.Equal(t, uint64(33), f.getShardByIndex(1).Stats().TotalCapacityBytes, "Shard 1 Global Capacity") + assert.Equal(t, uint64(33), f.getShardByIndex(2).Stats().TotalCapacityBytes, "Shard 2 Global Capacity") + + // Ensure the existing flow was synchronized to the new shards. + f.assertFlowExistsNow(flowID, pHigh) +} + +func TestFlowRegistry_Sharding_ScaleDown_Draining(t *testing.T) { + t.Parallel() + // Start with 3, scale to 1. + opts := fixtureOptions{ + customConfig: newShardingTestConfig(), + initialShardCount: 3, + useFakeClock: true, + } + f := newRegistryTestFixture(t, opts) + const flowID = "scale-down-flow" + specHigh := types.FlowSpecification{ID: flowID, Priority: pHigh} + + require.NoError(t, f.fr.RegisterOrUpdateFlow(specHigh), "Initial registration should succeed") + f.synchronizeControlPlane() + + // Identify shards. We rely on the implementation detail that the *last* shards in the slice are drained. + shards := f.fr.Shards() + require.Len(t, shards, 3) + activeShardID := shards[0].ID() + drainingShard1ID := shards[1].ID() + drainingShard2ID := shards[2].ID() + + // Add items specifically to the shards destined for removal. + // This prevents them from being immediately GC'd upon scale-down, allowing us to verify the draining state. + f.addItem(flowID, pHigh, drainingShard1ID, 10) + f.addItem(flowID, pHigh, drainingShard2ID, 10) + f.synchronizeControlPlane() + f.assertStatsNow(2, 20) + + // Action: Scale Down (3 -> 1). 
+ err := f.fr.UpdateShardCount(1) + require.NoError(t, err, "Scaling down should succeed") + f.synchronizeControlPlane() + + // Verification. + currentShards := f.fr.Shards() + // Total shards should still be 3 (1 active, 2 draining). + require.Len(t, currentShards, 3) + + assert.True(t, f.getShardByID(activeShardID).IsActive(), "Shard 0 should remain active") + assert.False(t, f.getShardByID(drainingShard1ID).IsActive(), "Shard 1 should be draining") + assert.False(t, f.getShardByID(drainingShard2ID).IsActive(), "Shard 2 should be draining") + + // The active shard should have its capacity re-partitioned to the full amount. + stats0 := f.getShardByID(activeShardID).Stats() + assert.Equal(t, uint64(100), stats0.TotalCapacityBytes, "Active shard should have full global capacity") + assert.Equal(t, uint64(50), stats0.PerPriorityBandStats[pHigh].CapacityBytes, + "Active shard band should have full capacity") +} + +// TestFlowRegistry_Sharding_ScaleDown_GCAndPurge verifies that a draining shard is fully garbage collected +// once it becomes empty, and crucially, that its ID is purged from all flow tracking maps (preventing memory leaks). +func TestFlowRegistry_Sharding_ScaleDown_GCAndPurge(t *testing.T) { + t.Parallel() + // Start with 2, scale to 1. + opts := fixtureOptions{ + customConfig: newShardingTestConfig(), + initialShardCount: 2, + useFakeClock: true, + } + f := newRegistryTestFixture(t, opts) + const flowID = "scale-down-gc" + specHigh := types.FlowSpecification{ID: flowID, Priority: pHigh} + + require.NoError(t, f.fr.RegisterOrUpdateFlow(specHigh), "Initial registration should succeed") + f.synchronizeControlPlane() + + // Identify the shard destined for draining (Shard 1, based on implementation detail). + shards := f.fr.Shards() + require.Len(t, shards, 2) + drainingShardID := shards[1].ID() + + // Add an item to the draining shard. + item := f.addItem(flowID, pHigh, drainingShardID, 10) + f.synchronizeControlPlane() + f.assertStatsNow(1, 10) + + // Action 1: Scale Down (2 -> 1). + require.NoError(t, f.fr.UpdateShardCount(1)) + f.synchronizeControlPlane() + + // Verify Draining State. + currentShards := f.fr.Shards() + require.Len(t, currentShards, 2, "Shard count should be 2 (1 active, 1 draining)") + require.False(t, f.getShardByID(drainingShardID).IsActive(), "Shard 1 should be draining") + + // Action 2: Empty the draining shard. This triggers the ShardBecameDrained signal. + f.removeItem(flowID, pHigh, drainingShardID, item) + f.synchronizeControlPlane() + + // Verification: Shard should be completely garbage collected. + finalShards := f.fr.Shards() + require.Len(t, finalShards, 1, "Drained shard was not GC'd") + assert.NotEqual(t, drainingShardID, finalShards[0].ID(), "The wrong shard was garbage collected") + + // Crucial Check: Ensure the decommissioned shard ID is purged from flow state (prevents memory leak). + f.fr.mu.Lock() + flowState, ok := f.fr.flowStates[flowID] + require.True(t, ok, "Flow should still exist") + assert.Len(t, flowState.activeQueueEmptyOnShards, 1, "Flow state tracking map size incorrect after purge") + assert.NotContains(t, flowState.activeQueueEmptyOnShards, drainingShardID, + "Old shard ID still present in tracking map") + f.fr.mu.Unlock() +} + +func TestFlowRegistry_Sharding_ErrorHandling(t *testing.T) { + t.Parallel() + opts := fixtureOptions{initialShardCount: 1, useFakeClock: true} + f := newRegistryTestFixture(t, opts) + + // Test invalid count (0). 
+ err := f.fr.UpdateShardCount(0) + assert.Error(t, err, "Updating shard count to 0 should fail") + assert.ErrorIs(t, err, contracts.ErrInvalidShardCount) + + // Test no-op update. + err = f.fr.UpdateShardCount(1) + assert.NoError(t, err, "Updating shard count to the same value should be a successful no-op") +} + +// --- Test Functions: Statistics and Concurrency --- + +func TestFlowRegistry_StatsAggregation(t *testing.T) { + t.Parallel() + config := &Config{ + MaxBytes: 10000, + PriorityBands: []PriorityBandConfig{ + {Priority: pHigh, PriorityName: "High", MaxBytes: 5000}, + {Priority: pLow, PriorityName: "Low", MaxBytes: 3000}, + }, + } + f := newRegistryTestFixture(t, fixtureOptions{initialShardCount: 2, customConfig: config, useFakeClock: true}) + // Get shard IDs for targeted additions. + shardIDs := f.getActiveShardIDs() + require.Len(t, shardIDs, 2) + s0ID, s1ID := shardIDs[0], shardIDs[1] + + require.NoError(t, f.fr.RegisterOrUpdateFlow(types.FlowSpecification{ID: "flow1-high", Priority: pHigh}), + "Registering flow1-high should succeed") + require.NoError(t, f.fr.RegisterOrUpdateFlow(types.FlowSpecification{ID: "flow2-low", Priority: pLow}), + "Registering flow2-low should succeed") + + // Add items distributed across flows, priorities, and shards. + f.addItem("flow1-high", pHigh, s0ID, 100) + f.addItem("flow1-high", pHigh, s1ID, 50) + f.addItem("flow2-low", pLow, s0ID, 150) + f.addItem("flow2-low", pLow, s0ID, 200) // Two items on the same flow/shard + + f.synchronizeControlPlane() + + // Verify Global Stats. + f.assertStatsNow(4, 500) // 100+50+150+200 = 500 + stats := f.fr.Stats() + assert.Equal(t, uint64(10000), stats.TotalCapacityBytes, "Global TotalCapacityBytes mismatch") + + // Verify Per-Priority Stats. + statsHigh := stats.PerPriorityBandStats[pHigh] + assert.Equal(t, uint64(5000), statsHigh.CapacityBytes, "High Priority CapacityBytes mismatch") + assert.Equal(t, uint64(2), statsHigh.Len, "High Priority Len mismatch") + assert.Equal(t, uint64(150), statsHigh.ByteSize, "High Priority ByteSize mismatch") + + statsLow := stats.PerPriorityBandStats[pLow] + assert.Equal(t, uint64(3000), statsLow.CapacityBytes, "Low Priority CapacityBytes mismatch") + assert.Equal(t, uint64(2), statsLow.Len, "Low Priority Len mismatch") + assert.Equal(t, uint64(350), statsLow.ByteSize, "Low Priority ByteSize mismatch") + + // Verify Shard Stats. We use the index here specifically to verify the partitioning logic. + // ShardStats() order matches the internal activeShards slice order. + shardStats := f.fr.ShardStats() + require.Len(t, shardStats, 2, "Expected 2 shard stats entries") + + // Find the stats corresponding to the IDs (order is deterministic). + var statsS0, statsS1 contracts.ShardStats + allShards := f.fr.Shards() + if allShards[0].ID() == s0ID { + statsS0, statsS1 = shardStats[0], shardStats[1] + } else { + // This branch should technically not happen given the deterministic initialization, but added for robustness. 
+ statsS1, statsS0 = shardStats[0], shardStats[1] + } + + // Shard 0 Stats (3 items, 450 bytes) + assert.Equal(t, uint64(3), statsS0.TotalLen, "Shard 0 TotalLen mismatch") + assert.Equal(t, uint64(450), statsS0.TotalByteSize, "Shard 0 TotalByteSize mismatch") + assert.Equal(t, uint64(5000), statsS0.TotalCapacityBytes, "Shard 0 TotalCapacityBytes mismatch") // 10000 / 2 + assert.Equal(t, uint64(1), statsS0.PerPriorityBandStats[pHigh].Len, "Shard 0 High Priority Len mismatch") + assert.Equal(t, uint64(2500), statsS0.PerPriorityBandStats[pHigh].CapacityBytes, + "Shard 0 High Priority CapacityBytes mismatch") // 5000 / 2 + assert.Equal(t, uint64(2), statsS0.PerPriorityBandStats[pLow].Len, "Shard 0 Low Priority Len mismatch") + assert.Equal(t, uint64(1500), statsS0.PerPriorityBandStats[pLow].CapacityBytes, + "Shard 0 Low Priority CapacityBytes mismatch") // 3000 / 2 + + // Shard 1 Stats (1 item, 50 bytes) + assert.Equal(t, uint64(1), statsS1.TotalLen, "Shard 1 TotalLen mismatch") + assert.Equal(t, uint64(50), statsS1.TotalByteSize, "Shard 1 TotalByteSize mismatch") +} + +// TestFlowRegistry_Backpressure verifies that the data path (e.g., `addItem`) blocks if the control plane event channel +// is full. This is critical for ensuring exactly-once event delivery required by the GC system. +func TestFlowRegistry_Backpressure(t *testing.T) { + t.Parallel() + + // Configure a minimal buffer size. + config := &Config{ + EventChannelBufferSize: 1, // Set buffer to 1 + PriorityBands: []PriorityBandConfig{ + {Priority: pHigh, PriorityName: "High"}, + }, + } + // Use `RealClock` (`useFakeClock: false`) as we are testing blocking behavior over time. + f := newRegistryTestFixture(t, fixtureOptions{initialShardCount: 1, customConfig: config, useFakeClock: false}) + + const flow1, flow2, flow3 = "f1", "f2", "f3" + require.NoError(t, f.fr.RegisterOrUpdateFlow(types.FlowSpecification{ID: flow1, Priority: pHigh}), "Registering flow1 should succeed") + require.NoError(t, f.fr.RegisterOrUpdateFlow(types.FlowSpecification{ID: flow2, Priority: pHigh}), "Registering flow2 should succeed") + require.NoError(t, f.fr.RegisterOrUpdateFlow(types.FlowSpecification{ID: flow3, Priority: pHigh}), "Registering flow3 should succeed") + f.synchronizeControlPlane() + + // Get the shard and queue instances *before* pausing the control plane. + shard := f.getShardByID(f.getActiveShardIDs()[0]) + mq1, err := shard.ManagedQueue(flow1, pHigh) + require.NoError(t, err) + mq2, err := shard.ManagedQueue(flow2, pHigh) + require.NoError(t, err) + mq3, err := shard.ManagedQueue(flow3, pHigh) + require.NoError(t, err) + + // 1. Pause the control plane event loop by acquiring the main lock. + // This prevents the loop from consuming events from the channel. + f.fr.mu.Lock() + + // 2. Fill the buffer. + // The first `addItem` generates a BecameNonEmpty event, filling the buffer (size 1). + item1 := mocks.NewMockQueueItemAccessor(10, "req-1", flow1) + require.NoError(t, mq1.Add(item1)) + + // The second `addItem` generates another event. The `addItem` call itself returns quickly, but the underlying + // `managedQueue`'s send to the channel will eventually fill the remaining space or block. + item2 := mocks.NewMockQueueItemAccessor(10, "req-2", flow2) + require.NoError(t, mq2.Add(item2)) + + // Wait briefly to ensure the internal atomic operations and channel sends have occurred. + time.Sleep(20 * time.Millisecond) + + // 3. Attempt an operation that generates a third event. This MUST block. 
+ operationCompleted := make(chan struct{}) + go func() { + // This call will eventually block inside `propagateStatsDelta` when trying to send the BecameNonEmpty event because + // the channel is full and the consumer (event loop) is paused. + item3 := mocks.NewMockQueueItemAccessor(10, "req-3", flow3) + err := mq3.Add(item3) + assert.NoError(t, err, "Add in goroutine failed") + close(operationCompleted) // Signal completion if it unblocks. + }() + + // Verify that it blocks (does not complete within a short duration). + select { + case <-operationCompleted: + f.fr.mu.Unlock() // Ensure unlock even on failure + t.Fatal("addItem did not block when the event channel was full. Backpressure failed.") + case <-time.After(100 * time.Millisecond): + // Success: The operation is blocked as expected. + } + + // 4. Unpause the control plane. + f.fr.mu.Unlock() + + // 5. Verify the blocked operation completes. + select { + case <-operationCompleted: + // Success: The operation unblocked after the control plane resumed. + case <-time.After(syncTimeout): + t.Fatal("addItem remained blocked after the control plane resumed.") + } + + // 6. Verify system consistency. + f.synchronizeControlPlane() + f.assertStatsNow(3, 30) +} + +// TestFlowRegistry_ConcurrencyStress performs concurrent administrative operations (registration, scaling) and data +// path operations (enqueue) to verify thread safety and statistical consistency. +func TestFlowRegistry_ConcurrencyStress(t *testing.T) { + t.Parallel() + const initialShards = 2 + f := newRegistryTestFixture(t, fixtureOptions{ + initialShardCount: initialShards, + useFakeClock: false, // Use `RealClock` for stress testing concurrency primitives. + customConfig: &Config{ + PriorityBands: []PriorityBandConfig{ + {Priority: pHigh, PriorityName: "High"}, + {Priority: pMed, PriorityName: "Medium"}, + {Priority: pLow, PriorityName: "Low"}, + }, + }, + }) + + const numAdminRoutines = 10 + const adminOpsPerRoutine = 50 + const numDataRoutines = 20 + const dataOpsPerRoutine = 200 + + var wg sync.WaitGroup + + // 1. Concurrent Administrative Operations (Registration/Updates). + wg.Add(numAdminRoutines) + for i := range numAdminRoutines { + go func(writerID int) { + defer wg.Done() + for j := range adminOpsPerRoutine { + // Cycle through a few flow IDs to generate updates, not just new registrations. + flowID := fmt.Sprintf("flow-admin-%d", (writerID+j)%5) + // Alternate priorities to trigger draining/reactivation logic concurrently. + priority := pHigh + if j%2 == 0 { + priority = pLow + } + spec := types.FlowSpecification{ID: flowID, Priority: priority} + err := f.fr.RegisterOrUpdateFlow(spec) + assert.NoError(t, err, "Concurrent RegisterOrUpdateFlow failed") + } + }(i) + } + + // 2. Concurrent Shard Scaling. + wg.Add(1) + go func() { + defer wg.Done() + for j := range 20 { + // Cycle shard count between 2, 3, and 4. + newCount := uint(initialShards + (j % 3)) + err := f.fr.UpdateShardCount(newCount) + assert.NoError(t, err, "Concurrent UpdateShardCount failed") + // Small sleep to allow other operations to proceed and increase contention. + time.Sleep(1 * time.Millisecond) + } + }() + + // 3. Concurrent Data Path Operations (Enqueue). + const dataFlowID = "data-path-flow" + require.NoError(t, f.fr.RegisterOrUpdateFlow(types.FlowSpecification{ID: dataFlowID, Priority: pMed})) + + wg.Add(numDataRoutines) + for i := range numDataRoutines { + go func(routineID int) { + defer wg.Done() + for j := range dataOpsPerRoutine { + // Attempt to enqueue on an available shard. 
+ shards := f.fr.Shards() + if len(shards) == 0 { + continue + } + // Select a shard based on the iteration count. + shard := shards[j%len(shards)] + + // We must handle potential errors gracefully, as the shard or the flow might be draining due to concurrent + // administrative operations (scaling or priority updates). + mq, err := shard.ActiveManagedQueue(dataFlowID) + if err != nil { + // Flow might not exist on this shard yet, or the shard might be draining. + continue + } + item := mocks.NewMockQueueItemAccessor(10, fmt.Sprintf("req-%d-%d", routineID, j), dataFlowID) + // Add might also fail if the queue transitions to draining just before the call. + _ = mq.Add(item) + } + }(i) + } + + wg.Wait() + + // 4. Final Consistency Check. + f.synchronizeControlPlane() + + globalStats := f.fr.Stats() + shardStats := f.fr.ShardStats() + + // The critical check: Ensure the sum of the parts (shards) equals the whole (global stats). + var aggregatedLen, aggregatedBytes uint64 + for _, s := range shardStats { + aggregatedLen += s.TotalLen + aggregatedBytes += s.TotalByteSize + } + + assert.Equal(t, aggregatedLen, globalStats.TotalLen, "Global length should match the sum of shard lengths") + assert.Equal(t, aggregatedBytes, globalStats.TotalByteSize, + "Global byte size should match the sum of shard byte sizes") +} diff --git a/pkg/epp/flowcontrol/registry/shard.go b/pkg/epp/flowcontrol/registry/shard.go index 1e894f07a..aedef53cc 100644 --- a/pkg/epp/flowcontrol/registry/shard.go +++ b/pkg/epp/flowcontrol/registry/shard.go @@ -27,44 +27,65 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework" inter "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/policies/interflow/dispatch" - intra "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/policies/intraflow/dispatch" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -// registryShard implements the `contracts.RegistryShard` interface. It represents a single, concurrent-safe slice of -// the `FlowRegistry`'s state, providing an operational view for a single `controller.FlowController` worker. +// shardCallbacks groups the callback functions that a `registryShard` uses to communicate with its parent registry. +type shardCallbacks struct { + propagateStatsDelta propagateStatsDeltaFunc + signalQueueState func(shardID string, spec types.FlowSpecification, signal queueStateSignal) + signalShardState signalShardStateFunc +} + +// registryShard implements the `contracts.RegistryShard` interface. It is the data plane counterpart to the +// `FlowRegistry`'s control plane. +// +// # Role: The Data Plane Slice +// +// It represents a single, concurrent-safe slice of the registry's total state, holding the partitioned configuration +// and the actual queue instances (`managedQueue`) for its assigned partition. It provides a read-optimized, operational +// view for a single `controller.FlowController` worker. +// +// The registryShard is deliberately kept simple regarding coordination logic; it relies on the parent `FlowRegistry` +// (the control plane) to orchestrate complex lifecycle events like registration and garbage collection. 
+// +// # Concurrency: `RWMutex` and Atomics // -// # Responsibilities +// The `registryShard` uses a hybrid approach to manage concurrency, balancing read performance with write safety: // -// - Holding the partitioned configuration and state (queues, policies) for its assigned shard. -// - Providing read-only access to its state for the `controller.FlowController`'s dispatch loop. -// - Aggregating statistics from its `managedQueue` instances. +// - `sync.RWMutex` (mu): Protects the shard's internal maps (`priorityBands`, `activeFlows`) during administrative +// operations (flow synchronization). This ensures that the set of available queues appears atomic to the +// `controller.FlowController` workers. All read-oriented methods acquire the read lock, allowing parallel access. // -// # Concurrency +// - Atomics (Stats): All statistics (`totalByteSize`, `totalLen`) are implemented using atomics. This allows for +// lock-free, high-performance updates on the data path hot path. // -// The `registryShard` uses a combination of an `RWMutex` and atomic operations to manage concurrency. -// - The `mu` RWMutex protects the shard's internal maps (`priorityBands`, `activeFlows`) during administrative -// operations like flow registration or updates. This ensures that the set of active or draining queues appears -// atomic to a `controller.FlowController` worker. All read-oriented methods on the shard take a read lock. -// - All statistics (`totalByteSize`, `totalLen`, etc.) are implemented as `atomic.Uint64` to allow for lock-free, -// high-performance updates from many concurrent queue operations. +// - Atomic Lifecycle (Status): The shard's lifecycle state (Active, Draining, Drained) is managed via an atomic +// `status` enum. This ensures robust, atomic state transitions and provides the mechanism for exactly-once signaling +// when the shard finishes draining (the transition from Draining to Drained). type registryShard struct { - id string - logger logr.Logger - config *Config // Holds the *partitioned* config for this shard. - isActive bool - reconcileFun parentStatsReconciler + id string + logger logr.Logger + config *Config // Holds the *partitioned* config for this shard. + + // status tracks the lifecycle state of the shard (Active, Draining, Drained). + // It is stored as an `int32` for atomic operations. + status atomic.Int32 // `componentStatus` + + // parentCallbacks provides the communication channels back to the parent registry. + parentCallbacks shardCallbacks // mu protects the shard's internal maps (`priorityBands` and `activeFlows`). mu sync.RWMutex - // priorityBands is the primary lookup table for all managed queues on this shard, organized by `priority`, then by - // `flowID`. This map contains BOTH active and draining queues. + // priorityBands is the primary lookup table for all managed queues on this shard, organized by priority, then by + // flow ID. This map contains BOTH active and draining queues. priorityBands map[uint]*priorityBand // activeFlows is a flattened map for O(1) access to the SINGLE active queue for a given logical flow ID. - // This is the critical lookup for the `Enqueue` path. If a `flowID` is not in this map, it has no active queue on - // this shard. + // This is the critical lookup for the `Enqueue` path. If a flow ID is not in this map, it has no active queue on this + // shard. activeFlows map[string]*managedQueue // orderedPriorityLevels is a cached, sorted list of `priority` levels. 
@@ -78,6 +99,7 @@ type registryShard struct { } // priorityBand holds all the `managedQueues` and configuration for a single priority level within a shard. +// It acts as a logical grouping for all state related to a specific priority. type priorityBand struct { // config holds the partitioned config for this specific band. config PriorityBandConfig @@ -90,28 +112,30 @@ type priorityBand struct { byteSize atomic.Uint64 len atomic.Uint64 - // Cached policy instances for this band, created at initialization. - interFlowDispatchPolicy framework.InterFlowDispatchPolicy - defaultIntraFlowDispatchPolicy framework.IntraFlowDispatchPolicy + // Cached policy instance for this band, created at initialization. + interFlowDispatchPolicy framework.InterFlowDispatchPolicy } +var _ contracts.RegistryShard = ®istryShard{} + // newShard creates a new `registryShard` instance from a partitioned configuration. func newShard( id string, partitionedConfig *Config, logger logr.Logger, - reconcileFunc parentStatsReconciler, + parentCallbacks shardCallbacks, ) (*registryShard, error) { shardLogger := logger.WithName("registry-shard").WithValues("shardID", id) s := ®istryShard{ - id: id, - logger: shardLogger, - config: partitionedConfig, - isActive: true, - reconcileFun: reconcileFunc, - priorityBands: make(map[uint]*priorityBand, len(partitionedConfig.PriorityBands)), - activeFlows: make(map[string]*managedQueue), + id: id, + logger: shardLogger, + config: partitionedConfig, + parentCallbacks: parentCallbacks, + priorityBands: make(map[uint]*priorityBand, len(partitionedConfig.PriorityBands)), + activeFlows: make(map[string]*managedQueue), } + // Initialize the shard in the Active state. + s.status.Store(int32(componentStatusActive)) for _, bandConfig := range partitionedConfig.PriorityBands { interPolicy, err := inter.NewPolicyFromName(bandConfig.InterFlowDispatchPolicy) @@ -120,70 +144,47 @@ func newShard( bandConfig.InterFlowDispatchPolicy, bandConfig.Priority, err) } - intraPolicy, err := intra.NewPolicyFromName(bandConfig.IntraFlowDispatchPolicy) - if err != nil { - return nil, fmt.Errorf("failed to create intra-flow policy %q for priority band %d: %w", - bandConfig.IntraFlowDispatchPolicy, bandConfig.Priority, err) - } - + // The intra-flow policy is instantiated on-demand by the `FlowRegistry`, not cached here. s.priorityBands[bandConfig.Priority] = &priorityBand{ - config: bandConfig, - queues: make(map[string]*managedQueue), - interFlowDispatchPolicy: interPolicy, - defaultIntraFlowDispatchPolicy: intraPolicy, + config: bandConfig, + queues: make(map[string]*managedQueue), + interFlowDispatchPolicy: interPolicy, } s.orderedPriorityLevels = append(s.orderedPriorityLevels, bandConfig.Priority) } - // Sort the priority levels to ensure deterministic iteration order. + // Sort priorities in ascending order (0 is highest priority). slices.Sort(s.orderedPriorityLevels) s.logger.V(logging.DEFAULT).Info("Registry shard initialized successfully", "priorityBandCount", len(s.priorityBands), "orderedPriorities", s.orderedPriorityLevels) return s, nil } -// reconcileStats is the single point of entry for all statistics changes within the shard. It updates the relevant -// band's stats, the shard's total stats, and propagates the delta to the parent registry. 
-func (s *registryShard) reconcileStats(priority uint, lenDelta, byteSizeDelta int64) { - s.totalLen.Add(uint64(lenDelta)) - s.totalByteSize.Add(uint64(byteSizeDelta)) - - if band, ok := s.priorityBands[priority]; ok { - band.len.Add(uint64(lenDelta)) - band.byteSize.Add(uint64(byteSizeDelta)) - } - - s.logger.V(logging.TRACE).Info("Reconciled shard stats", "priority", priority, - "lenDelta", lenDelta, "byteSizeDelta", byteSizeDelta) - - if s.reconcileFun != nil { - s.reconcileFun(lenDelta, byteSizeDelta) - } -} - // ID returns the unique identifier for this shard. func (s *registryShard) ID() string { return s.id } // IsActive returns true if the shard is active and accepting new requests. +// This is used by the `controller.FlowController` to determine if it should use this shard for new enqueue operations. func (s *registryShard) IsActive() bool { - s.mu.RLock() - defer s.mu.RUnlock() - return s.isActive + return componentStatus(s.status.Load()) == componentStatusActive } -// ActiveManagedQueue returns the currently active `ManagedQueue` for a given flow. +// ActiveManagedQueue returns the currently active `contracts.ManagedQueue` for a given flow. func (s *registryShard) ActiveManagedQueue(flowID string) (contracts.ManagedQueue, error) { s.mu.RLock() defer s.mu.RUnlock() mq, ok := s.activeFlows[flowID] if !ok { + // We do not check the shard's status here. Even if the shard is draining, the specific queue might still be active + // if the flow configuration hasn't changed. The queue itself will reject the `Add` if it is also draining. return nil, fmt.Errorf("failed to get active queue for flow %q: %w", flowID, contracts.ErrFlowInstanceNotFound) } return mq, nil } -// ManagedQueue retrieves a specific (potentially draining) `ManagedQueue` instance from this shard. +// ManagedQueue retrieves a specific (potentially draining or drained) `contracts.ManagedQueue` instance from this +// shard. func (s *registryShard) ManagedQueue(flowID string, priority uint) (contracts.ManagedQueue, error) { s.mu.RLock() defer s.mu.RUnlock() @@ -234,7 +235,7 @@ func (s *registryShard) InterFlowDispatchPolicy(priority uint) (framework.InterF return band.interFlowDispatchPolicy, nil } -// PriorityBandAccessor retrieves a read-only accessor for a given priority level. +// PriorityBandAccessor retrieves a read-only view for a given priority level. func (s *registryShard) PriorityBandAccessor(priority uint) (framework.PriorityBandAccessor, error) { s.mu.RLock() defer s.mu.RUnlock() @@ -247,8 +248,11 @@ func (s *registryShard) PriorityBandAccessor(priority uint) (framework.PriorityB return &priorityBandAccessor{shard: s, band: band}, nil } -// AllOrderedPriorityLevels returns all configured priority levels for this shard, sorted from highest to lowest -// priority (ascending numerical order). +// AllOrderedPriorityLevels returns a cached, sorted slice of all configured priority levels for this shard. +// The slice is sorted from highest to lowest priority (ascending numerical order). +// +// This list is cached at initialization to provide a stable, ordered view for the `controller.FlowController`'s +// dispatch loop, avoiding repeated map key iteration and sorting on the hot path. func (s *registryShard) AllOrderedPriorityLevels() []uint { // This is cached and read-only, so no lock is needed. 
return s.orderedPriorityLevels @@ -270,7 +274,7 @@ func (s *registryShard) Stats() contracts.ShardStats { stats.PerPriorityBandStats[priority] = contracts.PriorityBandStats{ Priority: priority, PriorityName: band.config.PriorityName, - CapacityBytes: band.config.MaxBytes, // This is the partitioned capacity + CapacityBytes: band.config.MaxBytes, // This is the partitioned capacity. ByteSize: band.byteSize.Load(), Len: band.len.Load(), } @@ -278,9 +282,182 @@ func (s *registryShard) Stats() contracts.ShardStats { return stats } -var _ contracts.RegistryShard = ®istryShard{} +// --- Internal Administrative/Lifecycle Methods (called by `FlowRegistry`) --- + +// synchronizeFlow is the internal administrative method for creating or updating a flow on this shard. +// The parent `FlowRegistry` handles validation and policy instantiation. This method implements the state machine for +// the `managedQueue` lifecycle and instantiates the underlying `framework.SafeQueue` only when necessary. +// +// This function atomically handles the following transitions under the shard's write lock: +// 1. Creation: If no queue exists, a new Active one is created. +// 2. No-Op Update: If an Active queue already exists at the target priority, nothing changes (no allocation). +// 3. Priority Change: The old Active queue is transitioned to Draining, and a new Active queue is created. +// 4. Reactivation: If a Draining/Drained queue exists at the target priority, it is transitioned back to Active (no +// allocation). +func (s *registryShard) synchronizeFlow(spec types.FlowSpecification, policy framework.IntraFlowDispatchPolicy, q framework.SafeQueue) { + s.mu.Lock() + defer s.mu.Unlock() + + // Check if an active queue already exists for this flow. + if existingActive, ok := s.activeFlows[spec.ID]; ok { + // If it's at the same priority, it's a no-op. + if existingActive.flowSpec.Priority == spec.Priority { + return + } + // It's a priority change. Mark the old one as draining. + s.logger.V(logging.TRACE).Info("Flow priority changed, marking old queue as draining.", "flowID", spec.ID, + "oldPriority", existingActive.flowSpec.Priority, "newPriority", spec.Priority) + existingActive.markAsDraining() + delete(s.activeFlows, spec.ID) + } + + // Now, either no active queue existed, or we just marked the old one as draining. + // We need to establish an active queue at the new priority. + targetBand := s.priorityBands[spec.Priority] + + // Case 1: A queue (Draining or Drained) exists at the target priority. Reactivate it. + if existingQueue, ok := targetBand.queues[spec.ID]; ok { + s.logger.V(logging.TRACE).Info("Found existing queue at target priority, reactivating.", "flowID", spec.ID, "priority", spec.Priority) + existingQueue.reactivate() + s.activeFlows[spec.ID] = existingQueue + return + } + + // Case 2: No queue exists at the target priority. Create a new one. + s.logger.V(logging.TRACE).Info("Creating new active queue for flow.", "flowID", spec.ID, "priority", spec.Priority, + "queueType", q.Name()) + + callbacks := managedQueueCallbacks{ + propagateStatsDelta: s.propagateStatsDelta, + signalQueueState: func(spec types.FlowSpecification, signal queueStateSignal) { + s.parentCallbacks.signalQueueState(s.id, spec, signal) + }, + } + // The provided `q` and `policy` are guaranteed to be non-nil by the caller (`FlowRegistry`). + mq := newManagedQueue(q, policy, spec, s.logger, callbacks) + targetBand.queues[spec.ID] = mq + s.activeFlows[spec.ID] = mq +} + +// garbageCollect removes a queue instance from the shard. 
+// It acquires the shard's write lock for the duration of the removal.
+func (s *registryShard) garbageCollect(flowID string, priority uint) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	s.logger.Info("Garbage collecting queue instance.", "flowID", flowID, "priority", priority)
+
+	// Remove from the priority band's map, which contains all instances (Active/Draining/Drained).
+	if band, ok := s.priorityBands[priority]; ok {
+		delete(band.queues, flowID)
+	}
+
+	// If this queue was the active one, also remove it from the activeFlows map.
+	// Note: A flow might be GC'd entirely (e.g., due to inactivity), in which case its active queue is removed here.
+	if activeQueue, ok := s.activeFlows[flowID]; ok {
+		if activeQueue.flowSpec.Priority == priority {
+			delete(s.activeFlows, flowID)
+		}
+	}
+}
+
+// markAsDraining transitions the shard to a draining state. It will no longer be considered active for new work by the
+// controller. The state transition itself is a lock-free atomic CAS; a read lock is taken only to snapshot the
+// constituent queues.
+func (s *registryShard) markAsDraining() {
+	// Attempt to transition from Active to Draining atomically.
+	if s.status.CompareAndSwap(int32(componentStatusActive), int32(componentStatusDraining)) {
+		s.logger.Info("Shard marked as draining")
+
+		// CRITICAL: Mark all constituent queues as draining as well.
+		// To prevent deadlocks, we collect the queues under a read lock, release it, and then perform the marking. This
+		// ensures we are not holding a lock when we invoke the callbacks inside `mq.markAsDraining()`.
+		var queuesToMark []*managedQueue
+		s.mu.RLock()
+		for _, band := range s.priorityBands {
+			for _, mq := range band.queues {
+				queuesToMark = append(queuesToMark, mq)
+			}
+		}
+		s.mu.RUnlock()
+
+		for _, mq := range queuesToMark {
+			// This ensures that even if a specific flow hasn't changed configuration, its queue on this specific shard stops
+			// accepting new traffic.
+			mq.markAsDraining()
+		}
+	}
+
+	// CRITICAL: Check if the shard is *already* empty at the moment it's marked as draining (or if it was already
+	// draining and empty). If so, we must immediately attempt the transition to Drained to ensure timely GC. This handles
+	// the race where the shard becomes empty just before or during being marked draining.
+	if s.totalLen.Load() == 0 {
+		// Attempt to transition from Draining to Drained atomically.
+		if s.status.CompareAndSwap(int32(componentStatusDraining), int32(componentStatusDrained)) {
+			s.parentCallbacks.signalShardState(s, shardStateSignalBecameDrained)
+		}
+	}
+}
+
+// updateConfig atomically replaces the shard's configuration. This is used during shard scaling events to re-partition
+// capacity allocations.
+func (s *registryShard) updateConfig(newConfig *Config) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	s.config = newConfig
+	// Update the partitioned config for each priority band as well.
+	for priority, band := range s.priorityBands {
+		newBandConfig, err := newConfig.getBandConfig(priority)
+		if err != nil {
+			// An invariant was violated: a priority exists in the shard but not in the new config.
+			// This should be impossible if the registry's logic is correct.
+			panic(fmt.Errorf("invariant violation: priority band (%d) missing in new configuration during update: %w",
+				priority, err))
+		}
+		band.config = *newBandConfig
+	}
+	s.logger.Info("Shard configuration updated")
+}
-// --- priorityBandAccessor ---
+// --- Internal Callback Methods ---
+
+// propagateStatsDelta is the single point of entry for all statistics changes within the shard.
It updates the relevant +// band's stats, the shard's total stats, and propagates the delta to the parent registry. It also handles the shard's +// lifecycle signaling. +func (s *registryShard) propagateStatsDelta(priority uint, lenDelta, byteSizeDelta int64) { + // This function uses two's complement arithmetic to atomically add or subtract from the unsigned counters. + // Casting a negative `int64` to `uint64` results in its two's complement representation, which, when added, is + // equivalent to subtraction. This is a standard and efficient pattern for atomic updates on unsigned integers. + newTotalLen := s.totalLen.Add(uint64(lenDelta)) + s.totalByteSize.Add(uint64(byteSizeDelta)) + + if band, ok := s.priorityBands[priority]; ok { + band.len.Add(uint64(lenDelta)) + band.byteSize.Add(uint64(byteSizeDelta)) + } else { + // This should be impossible if the `managedQueue` calling this is correctly registered. + panic(fmt.Sprintf("invariant violation: received stats propagation for unknown priority band (%d)", priority)) + } + + s.logger.V(logging.TRACE).Info("Propagated shard stats delta", "priority", priority, + "lenDelta", lenDelta, "byteSizeDelta", byteSizeDelta) + + s.parentCallbacks.propagateStatsDelta(priority, lenDelta, byteSizeDelta) + + // --- State Machine Signaling Logic --- + + // Check for Draining -> Drained transition (Exactly-Once). + // This must happen if the total length just hit zero. + if newTotalLen == 0 { + // Attempt to transition from Draining to Drained atomically. + // This acts as the exactly-once latch. If it succeeds, we are the single goroutine responsible for signaling. + if s.status.CompareAndSwap(int32(componentStatusDraining), int32(componentStatusDrained)) { + s.parentCallbacks.signalShardState(s, shardStateSignalBecameDrained) + } + } +} + +// --- `priorityBandAccessor` --- // priorityBandAccessor implements `framework.PriorityBandAccessor`. It provides a read-only, concurrent-safe view of a // single priority band within a shard. @@ -289,15 +466,13 @@ type priorityBandAccessor struct { band *priorityBand } +var _ framework.PriorityBandAccessor = &priorityBandAccessor{} + // Priority returns the numerical priority level of this band. -func (a *priorityBandAccessor) Priority() uint { - return a.band.config.Priority -} +func (a *priorityBandAccessor) Priority() uint { return a.band.config.Priority } // PriorityName returns the human-readable name of this priority band. -func (a *priorityBandAccessor) PriorityName() string { - return a.band.config.PriorityName -} +func (a *priorityBandAccessor) PriorityName() string { return a.band.config.PriorityName } // FlowIDs returns a slice of all flow IDs within this priority band. func (a *priorityBandAccessor) FlowIDs() []string { @@ -336,5 +511,3 @@ func (a *priorityBandAccessor) IterateQueues(callback func(queue framework.FlowQ } } } - -var _ framework.PriorityBandAccessor = &priorityBandAccessor{} diff --git a/pkg/epp/flowcontrol/registry/shard_test.go b/pkg/epp/flowcontrol/registry/shard_test.go index fe7f96e77..d629bd462 100644 --- a/pkg/epp/flowcontrol/registry/shard_test.go +++ b/pkg/epp/flowcontrol/registry/shard_test.go @@ -18,7 +18,9 @@ package registry import ( "errors" + "fmt" "sort" + "sync" "testing" "github.com/go-logr/logr" @@ -37,319 +39,760 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types/mocks" ) -// shardTestFixture holds the components needed for a `registryShard` test. 
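The two's-complement update and the compare-and-swap latch that `propagateStatsDelta` relies on can be isolated into a tiny runnable sketch, shown here for illustration only; the status constants and the `drainCounter` type are stand-ins, not the `componentStatus` values defined in this package.

package main

import (
	"fmt"
	"sync/atomic"
)

const (
	statusDraining int32 = iota + 1
	statusDrained
)

type drainCounter struct {
	totalLen atomic.Uint64
	status   atomic.Int32
	signaled atomic.Int32 // counts how many times the "became drained" signal fired
}

// applyDelta adds a signed delta to an unsigned counter: casting a negative
// int64 to uint64 yields its two's-complement form, so Add() effectively subtracts.
// When the count reaches zero, a CAS from Draining to Drained acts as an
// exactly-once latch for the drained signal.
func (d *drainCounter) applyDelta(lenDelta int64) {
	newLen := d.totalLen.Add(uint64(lenDelta))
	if newLen == 0 && d.status.CompareAndSwap(statusDraining, statusDrained) {
		d.signaled.Add(1) // only one caller can win the CAS
	}
}

func main() {
	d := &drainCounter{}
	d.status.Store(statusDraining)
	d.applyDelta(2)  // two items enqueued
	d.applyDelta(-1) // one removed
	d.applyDelta(-1) // last one removed: counter hits zero, latch fires once
	fmt.Println("len:", d.totalLen.Load(), "signals:", d.signaled.Load())
}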
-type shardTestFixture struct { - config *Config - shard *registryShard +// mockShardSignalRecorder is a thread-safe helper for recording shard state signals. +type mockShardSignalRecorder struct { + mu sync.Mutex + signals []shardStateSignal } -// setupTestShard creates a new test fixture for testing the `registryShard`. -func setupTestShard(t *testing.T) *shardTestFixture { +func (r *mockShardSignalRecorder) signal(shard *registryShard, signal shardStateSignal) { + r.mu.Lock() + defer r.mu.Unlock() + r.signals = append(r.signals, signal) +} + +func (r *mockShardSignalRecorder) getSignals() []shardStateSignal { + r.mu.Lock() + defer r.mu.Unlock() + // Return a copy to prevent data races if the caller iterates over the slice while new signals are concurrently added + // by the system under test. + signalsCopy := make([]shardStateSignal, len(r.signals)) + copy(signalsCopy, r.signals) + return signalsCopy +} + +// shardTestHarness holds the components needed for a `registryShard` test. +type shardTestHarness struct { + t *testing.T + config *Config + shard *registryShard + shardSignaler *mockShardSignalRecorder + statsPropagator *mockStatsPropagator +} + +// newShardTestHarness creates a new test harness for testing the `registryShard`. +func newShardTestHarness(t *testing.T) *shardTestHarness { t.Helper() config := &Config{ PriorityBands: []PriorityBandConfig{ - { - Priority: 10, - PriorityName: "High", - }, - { - Priority: 20, - PriorityName: "Low", - }, + {Priority: pHigh, PriorityName: "High"}, + {Priority: pLow, PriorityName: "Low"}, }, } // Apply defaults to the master config first, as the parent registry would. err := config.validateAndApplyDefaults() - require.NoError(t, err, "Setup: validating and defaulting config should not fail") + require.NoError(t, err, "Test setup: validating and defaulting config should not fail") + shardSignaler := &mockShardSignalRecorder{} + statsPropagator := &mockStatsPropagator{} // The parent registry would partition the config. For a single shard test, we can use the defaulted one directly. - shard, err := newShard("test-shard-1", config, logr.Discard(), nil) - require.NoError(t, err, "Setup: newShard should not return an error") - require.NotNil(t, shard, "Setup: newShard should return a non-nil shard") - - return &shardTestFixture{ - config: config, - shard: shard, + callbacks := shardCallbacks{ + propagateStatsDelta: statsPropagator.propagate, + signalQueueState: func(string, types.FlowSpecification, queueStateSignal) {}, + signalShardState: shardSignaler.signal, + } + shard, err := newShard("test-shard-1", config, logr.Discard(), callbacks) + require.NoError(t, err, "Test setup: newShard should not return an error") + require.NotNil(t, shard, "Test setup: newShard should return a non-nil shard") + + return &shardTestHarness{ + t: t, + config: config, + shard: shard, + shardSignaler: shardSignaler, + statsPropagator: statsPropagator, } } -// _reconcileFlow_testOnly is a test helper that simulates the future admin logic for adding or updating a flow. -// It creates a `managedQueue` and correctly populates the `priorityBands` and `activeFlows` maps. -// This helper is intended to be replaced by the real `reconcileFlow` method in a future PR. 
-func (s *registryShard) _reconcileFlow_testOnly( - t *testing.T, - flowSpec types.FlowSpecification, - isActive bool, -) *managedQueue { - t.Helper() - - band, ok := s.priorityBands[flowSpec.Priority] - require.True(t, ok, "Setup: priority band %d should exist", flowSpec.Priority) +// synchronizeFlow is a test helper that simulates the parent registry's logic for instantiating plugins and reconciling +// a flow. +func (h *shardTestHarness) synchronizeFlow(spec types.FlowSpecification) { + h.t.Helper() - lq, err := queue.NewQueueFromName(listqueue.ListQueueName, nil) - require.NoError(t, err, "Setup: creating a real listqueue should not fail") + // Look up the configuration from the master config, as the `FlowRegistry` would. + // We use the internal optimized lookup method populated during harness creation. + bandConfig, err := h.config.getBandConfig(spec.Priority) + // If `getBandConfig` fails, it means the priority band doesn't exist in the config. + require.NoError(h.t, err, "Test setup: priority band %d should exist in master config", spec.Priority) - mq := newManagedQueue( - lq, - band.defaultIntraFlowDispatchPolicy, - flowSpec, - logr.Discard(), - func(lenDelta, byteSizeDelta int64) { s.reconcileStats(flowSpec.Priority, lenDelta, byteSizeDelta) }, - ) - require.NotNil(t, mq, "Setup: newManagedQueue should not return nil") + policy, err := intra.NewPolicyFromName(bandConfig.IntraFlowDispatchPolicy) + require.NoError(h.t, err, "Test setup: failed to create intra-flow policy") - band.queues[flowSpec.ID] = mq - if isActive { - s.activeFlows[flowSpec.ID] = mq - } + q, err := queue.NewQueueFromName(bandConfig.Queue, policy.Comparator()) + require.NoError(h.t, err, "Test setup: failed to create queue") - return mq + h.shard.synchronizeFlow(spec, policy, q) } -func TestNewShard(t *testing.T) { - t.Parallel() - f := setupTestShard(t) - - assert.Equal(t, "test-shard-1", f.shard.ID(), "ID should be set correctly") - assert.True(t, f.shard.IsActive(), "A new shard should be active") - require.Len(t, f.shard.priorityBands, 2, "Should have 2 priority bands") - - // Check that priority levels are sorted correctly - assert.Equal(t, []uint{10, 20}, f.shard.AllOrderedPriorityLevels(), "Priority levels should be ordered") - - // Check band 10 - band10, ok := f.shard.priorityBands[10] - require.True(t, ok, "Priority band 10 should exist") - assert.Equal(t, uint(10), band10.config.Priority, "Band 10 should have correct priority") - assert.Equal(t, "High", band10.config.PriorityName, "Band 10 should have correct name") - assert.NotNil(t, band10.interFlowDispatchPolicy, "Inter-flow policy for band 10 should be instantiated") - assert.NotNil(t, band10.defaultIntraFlowDispatchPolicy, - "Default intra-flow policy for band 10 should be instantiated") - assert.Equal(t, besthead.BestHeadPolicyName, band10.interFlowDispatchPolicy.Name(), - "Correct default inter-flow policy should be used") - assert.Equal(t, fcfs.FCFSPolicyName, band10.defaultIntraFlowDispatchPolicy.Name(), - "Correct default intra-flow policy should be used") - - // Check band 20 - band20, ok := f.shard.priorityBands[20] - require.True(t, ok, "Priority band 20 should exist") - assert.Equal(t, uint(20), band20.config.Priority, "Band 20 should have correct priority") - assert.Equal(t, "Low", band20.config.PriorityName, "Band 20 should have correct name") - assert.NotNil(t, band20.interFlowDispatchPolicy, "Inter-flow policy for band 20 should be instantiated") - assert.NotNil(t, band20.defaultIntraFlowDispatchPolicy, - "Default intra-flow policy 
for band 20 should be instantiated") +// addItem is a test helper to add an item to a specific flow on the shard. +func (h *shardTestHarness) addItem(flowID string, priority uint, size uint64) types.QueueItemAccessor { + h.t.Helper() + mq, err := h.shard.ManagedQueue(flowID, priority) + require.NoError(h.t, err, "Helper addItem failed to get queue for flow %q at priority %d", flowID, priority) + item := mocks.NewMockQueueItemAccessor(size, "req", flowID) + require.NoError(h.t, mq.Add(item), "Helper addItem failed to add item to queue") + return item } -// TestNewShard_ErrorPaths modifies global plugin registries, so it cannot be run in parallel with other tests. -func TestNewShard_ErrorPaths(t *testing.T) { - baseConfig := &Config{ - PriorityBands: []PriorityBandConfig{{ - Priority: 10, - PriorityName: "High", - IntraFlowDispatchPolicy: fcfs.FCFSPolicyName, - InterFlowDispatchPolicy: besthead.BestHeadPolicyName, - Queue: listqueue.ListQueueName, - }}, - } - require.NoError(t, baseConfig.validateAndApplyDefaults(), "Setup: base config should be valid") +// removeItem is a test helper to remove an item from a specific flow's queue on the shard. +func (h *shardTestHarness) removeItem(flowID string, priority uint, item types.QueueItemAccessor) { + h.t.Helper() + mq, err := h.shard.ManagedQueue(flowID, priority) + require.NoError(h.t, err, "Helper removeItem failed to get queue for flow %q at priority %d", flowID, priority) + _, err = mq.Remove(item.Handle()) + require.NoError(h.t, err, "Helper removeItem failed to remove item from queue") +} - t.Run("Invalid InterFlow Policy", func(t *testing.T) { - // Register a mock policy that always fails to instantiate - failingPolicyName := inter.RegisteredPolicyName("failing-inter-policy") - inter.MustRegisterPolicy(failingPolicyName, func() (framework.InterFlowDispatchPolicy, error) { - return nil, errors.New("inter-flow instantiation failed") - }) +func TestShard(t *testing.T) { + t.Parallel() - badConfig := *baseConfig - badConfig.PriorityBands[0].InterFlowDispatchPolicy = failingPolicyName + t.Run("New_InitialState", func(t *testing.T) { + t.Parallel() + h := newShardTestHarness(t) + + assert.Equal(t, "test-shard-1", h.shard.ID(), "ID should be set correctly") + assert.True(t, h.shard.IsActive(), "A new shard should be active") + require.Len(t, h.shard.priorityBands, 2, "Should have 2 priority bands") + + // Check that priority levels are sorted correctly. + assert.Equal(t, []uint{pHigh, pLow}, h.shard.AllOrderedPriorityLevels(), "Priority levels should be ordered") + + // Check band `pHigh`. + bandHigh, ok := h.shard.priorityBands[pHigh] + require.True(t, ok, "High priority band should exist") + assert.Equal(t, pHigh, bandHigh.config.Priority, "High priority band should have correct priority") + assert.Equal(t, "High", bandHigh.config.PriorityName, "High priority band should have correct name") + assert.NotNil(t, bandHigh.interFlowDispatchPolicy, + "Inter-flow policy for high priority band should be instantiated") + assert.Equal(t, besthead.BestHeadPolicyName, bandHigh.interFlowDispatchPolicy.Name(), + "Correct default inter-flow policy should be used") + assert.Equal(t, string(fcfs.FCFSPolicyName), string(bandHigh.config.IntraFlowDispatchPolicy), + "Correct default intra-flow policy should be used") + + // Check band `pLow`. 
+ bandLow, ok := h.shard.priorityBands[pLow] + require.True(t, ok, "Low priority band should exist") + assert.Equal(t, pLow, bandLow.config.Priority, "Low priority band should have correct priority") + assert.Equal(t, "Low", bandLow.config.PriorityName, "Low priority band should have correct name") + assert.NotNil(t, bandLow.interFlowDispatchPolicy, "Inter-flow policy for low priority band should be instantiated") + }) - _, err := newShard("test", &badConfig, logr.Discard(), nil) - require.Error(t, err, "newShard should fail with an invalid inter-flow policy") + t.Run("Stats_Aggregation", func(t *testing.T) { + t.Parallel() + h := newShardTestHarness(t) + + // Add a queue and some items to test stats aggregation + h.synchronizeFlow(types.FlowSpecification{ID: "flow1", Priority: pHigh}) + h.addItem("flow1", pHigh, 100) + h.addItem("flow1", pHigh, 50) + + stats := h.shard.Stats() + + // Check shard-level stats + assert.Equal(t, uint64(2), stats.TotalLen, "Total length should be 2") + assert.Equal(t, uint64(150), stats.TotalByteSize, "Total byte size should be 150") + + // Check per-band stats + require.Len(t, stats.PerPriorityBandStats, 2, "Should have stats for 2 bands") + bandHighStats := stats.PerPriorityBandStats[pHigh] + assert.Equal(t, pHigh, bandHighStats.Priority, "High priority band stats should have correct priority") + assert.Equal(t, uint64(2), bandHighStats.Len, "High priority band length should be 2") + assert.Equal(t, uint64(150), bandHighStats.ByteSize, "High priority band byte size should be 150") + + bandLowStats := stats.PerPriorityBandStats[pLow] + assert.Equal(t, pLow, bandLowStats.Priority, "Low priority band stats should have correct priority") + assert.Zero(t, bandLowStats.Len, "Low priority band length should be 0") + assert.Zero(t, bandLowStats.ByteSize, "Low priority band byte size should be 0") }) - t.Run("Invalid IntraFlow Policy", func(t *testing.T) { - // Register a mock policy that always fails to instantiate - failingPolicyName := intra.RegisteredPolicyName("failing-intra-policy") - intra.MustRegisterPolicy(failingPolicyName, func() (framework.IntraFlowDispatchPolicy, error) { - return nil, errors.New("intra-flow instantiation failed") - }) + t.Run("Accessors", func(t *testing.T) { + t.Parallel() - badConfig := *baseConfig - badConfig.PriorityBands[0].IntraFlowDispatchPolicy = failingPolicyName + t.Run("Scenarios", func(t *testing.T) { + t.Parallel() + h := newShardTestHarness(t) + + flowID := "test-flow" + + // Setup state with one active and one draining queue for the same flow. + h.synchronizeFlow(types.FlowSpecification{ID: flowID, Priority: pHigh}) + h.addItem(flowID, pHigh, 1) // Add item so it doesn't immediately become Drained. + h.synchronizeFlow(types.FlowSpecification{ID: flowID, Priority: pLow}) + + // The second reconcile call makes the `pHigh` queue draining. 
+ activeQueue, err := h.shard.ActiveManagedQueue(flowID) + require.NoError(t, err) + drainingQueue, err := h.shard.ManagedQueue(flowID, pHigh) + require.NoError(t, err) + + t.Run("ActiveManagedQueue_ReturnsCorrectQueue", func(t *testing.T) { + t.Parallel() + retrievedActiveQueue, err := h.shard.ActiveManagedQueue(flowID) + require.NoError(t, err, "ActiveManagedQueue should not error for an existing flow") + assert.Same(t, activeQueue, retrievedActiveQueue, "Should return the correct active queue") + assert.Equal(t, pLow, retrievedActiveQueue.FlowQueueAccessor().FlowSpec().Priority, + "Active queue should have the correct priority") + + _, err = h.shard.ActiveManagedQueue("non-existent-flow") + require.Error(t, err, "ActiveManagedQueue should error for a non-existent flow") + assert.ErrorIs(t, err, contracts.ErrFlowInstanceNotFound, "Error should be ErrFlowInstanceNotFound") + }) - _, err := newShard("test", &badConfig, logr.Discard(), nil) - require.Error(t, err, "newShard should fail with an invalid intra-flow policy") - }) -} + t.Run("ManagedQueue_ReturnsDrainingQueue", func(t *testing.T) { + t.Parallel() + retrievedDrainingQueue, err := h.shard.ManagedQueue(flowID, pHigh) + require.NoError(t, err, "ManagedQueue should not error for a draining queue") + assert.Same(t, drainingQueue, retrievedDrainingQueue, "Should return the correct draining queue") -func TestShard_Stats(t *testing.T) { - t.Parallel() - f := setupTestShard(t) + // Verify the retrieved queue is in a draining state. + mq := retrievedDrainingQueue.(*managedQueue) + status := componentStatus(mq.status.Load()) + assert.Equal(t, componentStatusDraining, status, "Retrieved queue should be in draining status") + }) - // Add a queue and some items to test stats aggregation - mq := f.shard._reconcileFlow_testOnly(t, types.FlowSpecification{ID: "flow1", Priority: 10}, true) + t.Run("IntraFlowDispatchPolicy_ReturnsCorrectPolicy", func(t *testing.T) { + t.Parallel() + retrievedActivePolicy, err := h.shard.IntraFlowDispatchPolicy(flowID, pLow) + require.NoError(t, err, "IntraFlowDispatchPolicy should not error for an active instance") + assert.Same(t, activeQueue.(*managedQueue).dispatchPolicy, retrievedActivePolicy, + "Should return the policy from the active instance") + }) - // Add items - require.NoError(t, mq.Add(mocks.NewMockQueueItemAccessor(100, "req1", "flow1")), "Adding item should not fail") - require.NoError(t, mq.Add(mocks.NewMockQueueItemAccessor(50, "req2", "flow1")), "Adding item should not fail") + t.Run("InterFlowDispatchPolicy_ReturnsCorrectPolicy", func(t *testing.T) { + t.Parallel() + retrievedInterPolicy, err := h.shard.InterFlowDispatchPolicy(pHigh) + require.NoError(t, err, "InterFlowDispatchPolicy should not error for an existing priority") + assert.Same(t, h.shard.priorityBands[pHigh].interFlowDispatchPolicy, retrievedInterPolicy, + "Should return the correct inter-flow policy") + }) + }) - stats := f.shard.Stats() + t.Run("ErrorPaths", func(t *testing.T) { + t.Parallel() + h := newShardTestHarness(t) + h.synchronizeFlow(types.FlowSpecification{ID: "flow-a", Priority: pHigh}) + + testCases := []struct { + name string + action func() error + expectErr error + errMessage string + }{ + { + name: "ManagedQueue_WhenPriorityNotFound_ShouldFail", + action: func() error { + _, err := h.shard.ManagedQueue("flow-a", 99) + return err + }, + expectErr: contracts.ErrPriorityBandNotFound, + errMessage: "ManagedQueue should error for a non-existent priority", + }, + { + name: "ManagedQueue_WhenFlowNotFound_ShouldFail", + 
action: func() error { + _, err := h.shard.ManagedQueue("non-existent-flow", pHigh) + return err + }, + expectErr: contracts.ErrFlowInstanceNotFound, + errMessage: "ManagedQueue should error for flow not in band", + }, + { + name: "IntraFlowDispatchPolicy_WhenPriorityNotFound_ShouldFail", + action: func() error { + _, err := h.shard.IntraFlowDispatchPolicy("flow-a", 99) + return err + }, + expectErr: contracts.ErrPriorityBandNotFound, + errMessage: "IntraFlowDispatchPolicy should error for non-existent priority", + }, + { + name: "IntraFlowDispatchPolicy_WhenFlowNotFound_ShouldFail", + action: func() error { + // flow-a exists at priority `pHigh`, but flow-b does not. + _, err := h.shard.IntraFlowDispatchPolicy("flow-b", pHigh) + return err + }, + expectErr: contracts.ErrFlowInstanceNotFound, + errMessage: "IntraFlowDispatchPolicy should error for a flow not in the band", + }, + { + name: "InterFlowDispatchPolicy_WhenPriorityNotFound_ShouldFail", + action: func() error { + _, err := h.shard.InterFlowDispatchPolicy(99) + return err + }, + expectErr: contracts.ErrPriorityBandNotFound, + errMessage: "InterFlowDispatchPolicy should error for non-existent priority", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + err := tc.action() + require.Error(t, err, tc.errMessage) + assert.ErrorIs(t, err, tc.expectErr) + }) + } + }) + }) - // Check shard-level stats - assert.Equal(t, uint64(2), stats.TotalLen, "Total length should be 2") - assert.Equal(t, uint64(150), stats.TotalByteSize, "Total byte size should be 150") + t.Run("PriorityBandAccessor_Scenarios", func(t *testing.T) { + t.Parallel() + h := newShardTestHarness(t) - // Check per-band stats - require.Len(t, stats.PerPriorityBandStats, 2, "Should have stats for 2 bands") - band10Stats := stats.PerPriorityBandStats[10] - assert.Equal(t, uint(10), band10Stats.Priority, "Band 10 stats should have correct priority") - assert.Equal(t, uint64(2), band10Stats.Len, "Band 10 length should be 2") - assert.Equal(t, uint64(150), band10Stats.ByteSize, "Band 10 byte size should be 150") + // Setup shard state for the tests + h.synchronizeFlow(types.FlowSpecification{ID: "flow1", Priority: pHigh}) + h.synchronizeFlow(types.FlowSpecification{ID: "flow1", Priority: pLow}) // `pHigh` is now draining + h.synchronizeFlow(types.FlowSpecification{ID: "flow2", Priority: pHigh}) - band20Stats := stats.PerPriorityBandStats[20] - assert.Equal(t, uint(20), band20Stats.Priority, "Band 20 stats should have correct priority") - assert.Zero(t, band20Stats.Len, "Band 20 length should be 0") - assert.Zero(t, band20Stats.ByteSize, "Band 20 byte size should be 0") -} + t.Run("WhenPriorityExists_ShouldSucceed", func(t *testing.T) { + t.Parallel() + accessor, err := h.shard.PriorityBandAccessor(pHigh) + require.NoError(t, err, "PriorityBandAccessor should not fail for existing priority") + require.NotNil(t, accessor, "Accessor should not be nil") + + t.Run("Properties_ShouldReturnCorrectValues", func(t *testing.T) { + t.Parallel() + assert.Equal(t, pHigh, accessor.Priority(), "Accessor should have correct priority") + assert.Equal(t, "High", accessor.PriorityName(), "Accessor should have correct priority name") + }) -func TestShard_Accessors(t *testing.T) { - t.Parallel() - f := setupTestShard(t) - - flowID := "test-flow" - activePriority := uint(10) - drainingPriority := uint(20) - - // Setup state with one active and one draining queue for the same flow - activeQueue := f.shard._reconcileFlow_testOnly(t, types.FlowSpecification{ 
- ID: flowID, - Priority: activePriority, - }, true) - drainingQueue := f.shard._reconcileFlow_testOnly(t, types.FlowSpecification{ - ID: flowID, - Priority: drainingPriority, - }, false) - - t.Run("ActiveManagedQueue", func(t *testing.T) { - t.Parallel() - retrievedActiveQueue, err := f.shard.ActiveManagedQueue(flowID) - require.NoError(t, err, "ActiveManagedQueue should not error for an existing flow") - assert.Same(t, activeQueue, retrievedActiveQueue, "Should return the correct active queue") + t.Run("FlowIDs_ShouldReturnAllFlowsInBand", func(t *testing.T) { + t.Parallel() + flowIDs := accessor.FlowIDs() + sort.Strings(flowIDs) + assert.Equal(t, []string{"flow1", "flow2"}, flowIDs, + "Accessor should return correct flow IDs for the priority band") + }) - _, err = f.shard.ActiveManagedQueue("non-existent-flow") - require.Error(t, err, "ActiveManagedQueue should error for a non-existent flow") - assert.ErrorIs(t, err, contracts.ErrFlowInstanceNotFound, "Error should be ErrFlowInstanceNotFound") - }) + t.Run("Queue_ShouldReturnCorrectAccessor", func(t *testing.T) { + t.Parallel() + q := accessor.Queue("flow1") + require.NotNil(t, q, "Accessor should return queue for flow1") + assert.Equal(t, pHigh, q.FlowSpec().Priority, "Queue should have the correct priority") + assert.Nil(t, accessor.Queue("non-existent"), "Accessor should return nil for non-existent flow") + }) - t.Run("ManagedQueue", func(t *testing.T) { - t.Parallel() - retrievedDrainingQueue, err := f.shard.ManagedQueue(flowID, drainingPriority) - require.NoError(t, err, "ManagedQueue should not error for a draining queue") - assert.Same(t, drainingQueue, retrievedDrainingQueue, "Should return the correct draining queue") + t.Run("IterateQueues_ShouldVisitAllQueues", func(t *testing.T) { + t.Parallel() + var iteratedFlows []string + accessor.IterateQueues(func(queue framework.FlowQueueAccessor) bool { + iteratedFlows = append(iteratedFlows, queue.FlowSpec().ID) + return true + }) + sort.Strings(iteratedFlows) + assert.Equal(t, []string{"flow1", "flow2"}, iteratedFlows, "IterateQueues should visit all flows in the band") + }) - _, err = f.shard.ManagedQueue(flowID, 99) // Non-existent priority - require.Error(t, err, "ManagedQueue should error for a non-existent priority") - assert.ErrorIs(t, err, contracts.ErrPriorityBandNotFound, "Error should be ErrPriorityBandNotFound") + t.Run("IterateQueues_ShouldExitEarly", func(t *testing.T) { + t.Parallel() + var iteratedFlows []string + accessor.IterateQueues(func(queue framework.FlowQueueAccessor) bool { + iteratedFlows = append(iteratedFlows, queue.FlowSpec().ID) + return false // Exit after first item + }) + assert.Len(t, iteratedFlows, 1, "IterateQueues should exit early if callback returns false") + }) + }) - _, err = f.shard.ManagedQueue("non-existent-flow", activePriority) - require.Error(t, err, "ManagedQueue should error for a non-existent flow in an existing priority") - assert.ErrorIs(t, err, contracts.ErrFlowInstanceNotFound, "Error should be ErrFlowInstanceNotFound") + t.Run("WhenPriorityDoesNotExist_ShouldFail", func(t *testing.T) { + t.Parallel() + _, err := h.shard.PriorityBandAccessor(99) + require.Error(t, err, "PriorityBandAccessor should fail for non-existent priority") + assert.ErrorIs(t, err, contracts.ErrPriorityBandNotFound, "Error should be ErrPriorityBandNotFound") + }) }) - t.Run("IntraFlowDispatchPolicy", func(t *testing.T) { + t.Run("Lifecycle", func(t *testing.T) { t.Parallel() - retrievedActivePolicy, err := f.shard.IntraFlowDispatchPolicy(flowID, 
activePriority) - require.NoError(t, err, "IntraFlowDispatchPolicy should not error for an active instance") - assert.Same(t, activeQueue.dispatchPolicy, retrievedActivePolicy, - "Should return the policy from the active instance") - - _, err = f.shard.IntraFlowDispatchPolicy("non-existent-flow", activePriority) - require.Error(t, err, "IntraFlowDispatchPolicy should error for a non-existent flow") - assert.ErrorIs(t, err, contracts.ErrFlowInstanceNotFound, "Error should be ErrFlowInstanceNotFound") - - _, err = f.shard.IntraFlowDispatchPolicy(flowID, 99) // Non-existent priority - require.Error(t, err, "IntraFlowDispatchPolicy should error for a non-existent priority") - assert.ErrorIs(t, err, contracts.ErrPriorityBandNotFound, "Error should be ErrPriorityBandNotFound") - }) - t.Run("InterFlowDispatchPolicy", func(t *testing.T) { - t.Parallel() - retrievedInterPolicy, err := f.shard.InterFlowDispatchPolicy(activePriority) - require.NoError(t, err, "InterFlowDispatchPolicy should not error for an existing priority") - assert.Same(t, f.shard.priorityBands[activePriority].interFlowDispatchPolicy, retrievedInterPolicy, - "Should return the correct inter-flow policy") - - _, err = f.shard.InterFlowDispatchPolicy(99) // Non-existent priority - require.Error(t, err, "InterFlowDispatchPolicy should error for a non-existent priority") - assert.ErrorIs(t, err, contracts.ErrPriorityBandNotFound, "Error should be ErrPriorityBandNotFound") - }) -} + t.Run("SynchronizeFlow", func(t *testing.T) { + t.Parallel() + const flowID = "test-flow" + specHigh := types.FlowSpecification{ID: flowID, Priority: pHigh} + specLow := types.FlowSpecification{ID: flowID, Priority: pLow} + + t.Run("ForNewFlow_ShouldCreateActiveQueue", func(t *testing.T) { + t.Parallel() + h := newShardTestHarness(t) + h.synchronizeFlow(specHigh) + + // Assert state. + assert.Contains(t, h.shard.activeFlows, flowID, "Flow should be in the active map") + activeQueue := h.shard.activeFlows[flowID] + assert.Equal(t, pHigh, activeQueue.flowSpec.Priority, "Active flow should have the correct priority") + assert.Equal(t, componentStatusActive, componentStatus(activeQueue.status.Load()), "Queue should be active") + + band, ok := h.shard.priorityBands[pHigh] + require.True(t, ok, "High priority band should exist") + assert.Contains(t, band.queues, flowID, "Queue should be in the correct priority band map") + }) -func TestShard_PriorityBandAccessor(t *testing.T) { - t.Parallel() - f := setupTestShard(t) + t.Run("ForExistingFlowAtSamePriority_ShouldBeNoOp", func(t *testing.T) { + t.Parallel() + h := newShardTestHarness(t) + // Initial synchronization. + h.synchronizeFlow(specHigh) + initialQueue := h.shard.activeFlows[flowID] - // Setup shard state for the tests - p1, p2 := uint(10), uint(20) - f.shard._reconcileFlow_testOnly(t, types.FlowSpecification{ID: "flow1", Priority: p1}, true) - f.shard._reconcileFlow_testOnly(t, types.FlowSpecification{ID: "flow1", Priority: p2}, false) - f.shard._reconcileFlow_testOnly(t, types.FlowSpecification{ID: "flow2", Priority: p1}, true) + // Second synchronization with the same spec, + h.synchronizeFlow(specHigh) - t.Run("Accessor for existing priority", func(t *testing.T) { - t.Parallel() - accessor, err := f.shard.PriorityBandAccessor(p1) - require.NoError(t, err, "PriorityBandAccessor should not fail for existing priority") - require.NotNil(t, accessor, "Accessor should not be nil") + // Assert state is unchanged. 
+ assert.Len(t, h.shard.activeFlows, 1, "There should still be only one active flow") + assert.Same(t, initialQueue, h.shard.activeFlows[flowID], "The queue instance should not have been replaced") + }) + + t.Run("ForPriorityChange_ShouldDrainOldQueue", func(t *testing.T) { + t.Parallel() + h := newShardTestHarness(t) + h.synchronizeFlow(specHigh) + h.addItem(flowID, pHigh, 1) // Add item so the queue doesn't immediately become Drained. + oldQueue := h.shard.activeFlows[flowID] + + // Synchronize with the new, lower priority. + h.synchronizeFlow(specLow) + + // Assert state of the new active queue. + assert.Contains(t, h.shard.activeFlows, flowID, "Flow should still be active") + newActiveQueue := h.shard.activeFlows[flowID] + assert.Equal(t, pLow, newActiveQueue.flowSpec.Priority, "Active flow should now have the low priority") + assert.Equal(t, componentStatusActive, componentStatus(newActiveQueue.status.Load()), + "New queue should be active") + + // Assert state of the old, now-draining queue. + assert.Contains(t, h.shard.priorityBands[pHigh].queues, flowID, + "Old queue should still exist in the high priority band") + assert.Equal(t, componentStatusDraining, componentStatus(oldQueue.status.Load()), + "Old queue should be marked as draining") + }) + + t.Run("ForPriorityRollback_ShouldReactivateDrainingQueue", func(t *testing.T) { + t.Parallel() + h := newShardTestHarness(t) + // Step 1: Create active queue at `pHigh`. + h.synchronizeFlow(specHigh) + drainingQueueCandidate := h.shard.activeFlows[flowID] + h.addItem(flowID, pHigh, 1) + + // Step 2: Change priority to `pLow`, making the `pHigh` queue drain. + h.synchronizeFlow(specLow) + assert.Equal(t, componentStatusDraining, componentStatus(drainingQueueCandidate.status.Load()), + "Queue at pHigh should be draining") + + // Step 3: Roll back to `pHigh`. + h.synchronizeFlow(specHigh) + + // Assert that the original queue was reactivated + assert.Len(t, h.shard.activeFlows, 1, "There should be one active flow") + assert.Same(t, drainingQueueCandidate, h.shard.activeFlows[flowID], + "The original queue should have been reactivated") + assert.Equal(t, componentStatusActive, componentStatus(h.shard.activeFlows[flowID].status.Load()), + "The reactivated queue should be marked as active") + + // Assert that the pLow queue is now draining. + lowPriorityBand := h.shard.priorityBands[pLow] + assert.Contains(t, lowPriorityBand.queues, flowID, "The pLow queue should now be in the low priority band") + lowQueue := lowPriorityBand.queues[flowID] + // Since the low-priority queue was empty when the rollback happened, it should be immediately Drained. + assert.Equal(t, componentStatusDrained, componentStatus(lowQueue.status.Load()), + "The pLow queue should now be drained") + }) + + t.Run("GarbageCollect_ShouldRemoveDrainingQueue", func(t *testing.T) { + t.Parallel() + h := newShardTestHarness(t) + h.synchronizeFlow(specHigh) + h.synchronizeFlow(specLow) // `specHigh` is now draining - t.Run("Properties", func(t *testing.T) { + h.shard.garbageCollect(flowID, pHigh) + + assert.NotContains(t, h.shard.priorityBands[pHigh].queues, flowID, + "Queue should have been removed from the priority band") + }) + + t.Run("GarbageCollect_ShouldRemoveActiveQueue", func(t *testing.T) { + t.Parallel() + h := newShardTestHarness(t) + h.synchronizeFlow(specHigh) + + // Verify the queue is active before GC. 
+ require.Contains(t, h.shard.activeFlows, flowID, "Test setup: queue must be in active map before GC") + + h.shard.garbageCollect(flowID, pHigh) + + assert.NotContains(t, h.shard.priorityBands[pHigh].queues, flowID, + "Queue should have been removed from the priority band map") + assert.NotContains(t, h.shard.activeFlows, flowID, "Queue should have been removed from the active flows map") + }) + }) + + t.Run("DrainingTransitions", func(t *testing.T) { t.Parallel() - assert.Equal(t, p1, accessor.Priority(), "Accessor should have correct priority") - assert.Equal(t, "High", accessor.PriorityName(), "Accessor should have correct priority name") + + t.Run("MarkAsDraining_OnNonEmptyShard_ShouldTransitionToDraining", func(t *testing.T) { + t.Parallel() + h := newShardTestHarness(t) + h.synchronizeFlow(types.FlowSpecification{ID: "flow1", Priority: pHigh}) + h.synchronizeFlow(types.FlowSpecification{ID: "flow2", Priority: pLow}) + + // Add items to make queues non-empty. + h.addItem("flow1", pHigh, 1) + h.addItem("flow2", pLow, 1) + + // Mark the shard as draining. + h.shard.markAsDraining() + + // Assert shard status. + assert.False(t, h.shard.IsActive(), "Shard should no longer be active") + assert.Equal(t, componentStatusDraining, componentStatus(h.shard.status.Load()), + "Shard status should be Draining") + + // Assert status of constituent queues + mq1, _ := h.shard.ManagedQueue("flow1", pHigh) + mq2, _ := h.shard.ManagedQueue("flow2", pLow) + assert.Equal(t, componentStatusDraining, componentStatus(mq1.(*managedQueue).status.Load()), + "Queue for flow1 should be draining") + assert.Equal(t, componentStatusDraining, componentStatus(mq2.(*managedQueue).status.Load()), + "Queue for flow2 should be draining") + }) + + t.Run("MarkAsDraining_OnEmptyShard_ShouldTransitionToDrainedAndSignal", func(t *testing.T) { + t.Parallel() + h := newShardTestHarness(t) + + // Mark the empty shard as draining. + h.shard.markAsDraining() + + // Assert status and signal. + assert.Equal(t, componentStatusDrained, componentStatus(h.shard.status.Load()), + "Shard status should be Drained") + require.Len(t, h.shardSignaler.getSignals(), 1, "A signal should have been sent") + assert.Equal(t, shardStateSignalBecameDrained, h.shardSignaler.signals[0], + "The correct signal should have been sent") + }) + + t.Run("WhenDraining_ShouldTransitionToDrainedWhenLastItemIsRemoved", func(t *testing.T) { + t.Parallel() + h := newShardTestHarness(t) + h.synchronizeFlow(types.FlowSpecification{ID: "flow1", Priority: pHigh}) + item1 := h.addItem("flow1", pHigh, 100) + item2 := h.addItem("flow1", pHigh, 200) + + // Mark the shard as draining; it should be Draining, not Drained, since it's not empty. + h.shard.markAsDraining() + assert.Equal(t, componentStatusDraining, componentStatus(h.shard.status.Load()), + "Shard should be Draining while it contains items") + assert.Empty(t, h.shardSignaler.getSignals(), + "No signal should be sent while the shard is still draining with items") + + // Remove one item; it should still be Draining. + h.removeItem("flow1", pHigh, item1) + assert.Equal(t, componentStatusDraining, componentStatus(h.shard.status.Load()), + "Shard should remain Draining after one item is removed") + assert.Empty(t, h.shardSignaler.getSignals(), "No signal should be sent yet") + + // Remove the final item; it should now transition to Drained and signal. 
+ h.removeItem("flow1", pHigh, item2) + assert.Equal(t, componentStatusDrained, componentStatus(h.shard.status.Load()), + "Shard should become Drained after the last item is removed") + require.Len(t, h.shardSignaler.getSignals(), 1, "A signal should have been sent upon becoming empty") + assert.Equal(t, shardStateSignalBecameDrained, h.shardSignaler.signals[0], + "The correct drained signal should have been sent") + }) }) - t.Run("FlowIDs", func(t *testing.T) { + // DrainingRaceWithConcurrentRemovals targets the race where multiple goroutines remove the last few items from a + // draining shard. It ensures the Draining -> Drained transition happens exactly once. + t.Run("DrainingRaceWithConcurrentRemovals", func(t *testing.T) { t.Parallel() - flowIDs := accessor.FlowIDs() - sort.Strings(flowIDs) - assert.Equal(t, []string{"flow1", "flow2"}, flowIDs, - "Accessor should return correct flow IDs for the priority band") + h := newShardTestHarness(t) + + numItems := 50 + var items []types.QueueItemAccessor + for i := range numItems { + // Add items to two different flows to make it more realistic. + flowID := fmt.Sprintf("flow-%d", i%2) + h.synchronizeFlow(types.FlowSpecification{ID: flowID, Priority: pHigh}) + item := h.addItem(flowID, pHigh, 10) + items = append(items, item) + } + + // Mark the shard as draining. It has items, so it will enter the Draining state. + h.shard.markAsDraining() + require.Equal(t, componentStatusDraining, componentStatus(h.shard.status.Load()), + "Test setup: shard must be in Draining state") + + var wg sync.WaitGroup + wg.Add(numItems) + for _, item := range items { + go func(it types.QueueItemAccessor) { + defer wg.Done() + h.removeItem(it.OriginalRequest().FlowID(), pHigh, it) + }(item) + } + wg.Wait() + + // Verification: + // No matter which goroutine removed the "last" item, the final state must be Drained, and the signal must have + // been sent exactly once. + assert.Equal(t, componentStatusDrained, componentStatus(h.shard.status.Load()), "Final state must be Drained") + assert.Len(t, h.shardSignaler.getSignals(), 1, "BecameDrained signal must be sent exactly once") + assert.Equal(t, shardStateSignalBecameDrained, h.shardSignaler.signals[0], "The correct signal must be sent") }) - t.Run("Queue", func(t *testing.T) { + // DrainingRaceWithMarkAndEmpty targets the race between markAsDraining() and the shard becoming empty via a + // concurrent item removal. It proves the atomic CAS correctly arbitrates the race. + t.Run("DrainingRaceWithMarkAndEmpty", func(t *testing.T) { t.Parallel() - q := accessor.Queue("flow1") - require.NotNil(t, q, "Accessor should return queue for flow1") - assert.Equal(t, p1, q.FlowSpec().Priority, "Queue should have the correct priority") - assert.Nil(t, accessor.Queue("non-existent"), "Accessor should return nil for non-existent flow") + h := newShardTestHarness(t) + + // Test setup: Start with a shard containing a single item. + h.synchronizeFlow(types.FlowSpecification{ID: "flow1", Priority: pHigh}) + item := h.addItem("flow1", pHigh, 1) + + var wg sync.WaitGroup + wg.Add(2) + + // Goroutine 1: Attempts to mark the shard as draining. + go func() { + defer wg.Done() + h.shard.markAsDraining() + }() + + // Goroutine 2: Concurrently removes the single item. + go func() { + defer wg.Done() + h.removeItem("flow1", pHigh, item) + }() + + wg.Wait() + + // Verification: + // Either markAsDraining found an empty queue, or propagateStatsDelta found a draining queue. 
+ // In either case, the final state must be Drained and the signal must have been sent exactly once. + assert.Equal(t, componentStatusDrained, componentStatus(h.shard.status.Load()), + "Final state must be Drained regardless of race outcome") + assert.Len(t, h.shardSignaler.getSignals(), 1, "BecameDrained signal must be sent exactly once") }) + }) - t.Run("IterateQueues", func(t *testing.T) { + t.Run("UpdateConfig_UpdatesInternalState", func(t *testing.T) { + t.Parallel() + h := newShardTestHarness(t) + h.synchronizeFlow(types.FlowSpecification{ID: "flow1", Priority: pHigh}) + + // Create a new config with different values. + newConfig := h.config.deepCopy() + newConfig.MaxBytes = 9999 + newConfig.PriorityBands[0].MaxBytes = 8888 // for priority `pHigh` + + // Update the shard's config. + h.shard.updateConfig(newConfig) + + // Assert that the shard's internal config pointer was updated. + assert.Same(t, newConfig, h.shard.config, "Shard's internal config pointer should be updated") + + // Assert that the stats reflect the new capacity. + stats := h.shard.Stats() + assert.Equal(t, uint64(9999), stats.TotalCapacityBytes, "Shard's total capacity should be updated in stats") + bandHighStats, ok := stats.PerPriorityBandStats[pHigh] + require.True(t, ok, "Stats for high priority band should exist") + assert.Equal(t, uint64(8888), bandHighStats.CapacityBytes, "Priority band's capacity should be updated in stats") + + // Assert that the config within the internal `priorityBand` struct was also updated. + bandHigh, ok := h.shard.priorityBands[pHigh] + require.True(t, ok, "Internal high priority band struct should exist") + assert.Equal(t, uint64(8888), bandHigh.config.MaxBytes, "Internal priority band's config should be updated") + }) + + // TestShard_Invariants_PanicOnCorruption tests conditions that should cause a panic, as they represent a corrupted or + // inconsistent state that cannot be recovered from. + t.Run("Invariants_PanicOnCorruption", func(t *testing.T) { + t.Parallel() + + t.Run("PropagateStatsDelta_WithUnknownPriority_ShouldPanic", func(t *testing.T) { t.Parallel() - var iteratedFlows []string - accessor.IterateQueues(func(queue framework.FlowQueueAccessor) bool { - iteratedFlows = append(iteratedFlows, queue.FlowSpec().ID) - return true - }) - sort.Strings(iteratedFlows) - assert.Equal(t, []string{"flow1", "flow2"}, iteratedFlows, "IterateQueues should visit all flows in the band") + h := newShardTestHarness(t) + // Manually create a managedQueue with a priority that does not exist in the shard's config. + // This simulates a state corruption where an invalid queue instance is somehow created. + invalidPriority := uint(99) + spec := types.FlowSpecification{ID: "bad-flow", Priority: invalidPriority} + q, _ := queue.NewQueueFromName(listqueue.ListQueueName, nil) + policy, _ := intra.NewPolicyFromName(fcfs.FCFSPolicyName) + callbacks := managedQueueCallbacks{ + propagateStatsDelta: h.shard.propagateStatsDelta, // Use the real shard's propagator. + signalQueueState: func(spec types.FlowSpecification, signal queueStateSignal) {}, + } + mq := newManagedQueue(q, policy, spec, logr.Discard(), callbacks) + + // The call to propagateStatsDelta is what should panic. + // This is triggered by calling `Add()` on the manually created queue. 
+ expectedPanicMsg := fmt.Sprintf("invariant violation: received stats propagation for unknown priority band (%d)", + invalidPriority) + assert.PanicsWithValue(t, + expectedPanicMsg, + func() { + // This call will trigger the panic inside the shard's callback. + _ = mq.Add(mocks.NewMockQueueItemAccessor(1, "req", "bad-flow")) + }, + "propagateStatsDelta must panic when called with a priority that doesn't exist on the shard", + ) }) - t.Run("IterateQueues with early exit", func(t *testing.T) { + t.Run("UpdateConfig_WithMissingPriorityBand_ShouldPanic", func(t *testing.T) { t.Parallel() - var iteratedFlows []string - accessor.IterateQueues(func(queue framework.FlowQueueAccessor) bool { - iteratedFlows = append(iteratedFlows, queue.FlowSpec().ID) - return false // Exit after first item - }) - assert.Len(t, iteratedFlows, 1, "IterateQueues should exit early if callback returns false") + h := newShardTestHarness(t) // This harness has priorities `pHigh` and `pLow`. + + // Create a new config that is missing one of the shard's existing priority bands (priority `pLow` is missing). + newConfig := &Config{ + PriorityBands: []PriorityBandConfig{ + {Priority: pHigh, PriorityName: "High-Updated"}, + }, + } + err := newConfig.validateAndApplyDefaults() + require.NoError(t, err, "Test setup: creating updated config should not fail") + + // This call should panic because the shard has a band for `pLow`, but the new config does not. + // The full error string is complex due to error wrapping. + expectedErrStr := "invariant violation: priority band (30) missing in new configuration during update: config for priority 30 not found: priority band not found" + assert.PanicsWithError(t, + expectedErrStr, + func() { + h.shard.updateConfig(newConfig) + }, + "updateConfig must panic when an existing priority band is missing from the new config", + ) }) }) +} - t.Run("Error on non-existent priority", func(t *testing.T) { - t.Parallel() - _, err := f.shard.PriorityBandAccessor(99) - require.Error(t, err, "PriorityBandAccessor should fail for non-existent priority") - assert.ErrorIs(t, err, contracts.ErrPriorityBandNotFound, "Error should be ErrPriorityBandNotFound") +// TestShard_New_ErrorPaths modifies a global plugin registry, so it cannot be run in parallel with other tests that +// might also manipulate the same global state. 
+func TestShard_New_ErrorPaths(t *testing.T) { + baseConfig := &Config{ + PriorityBands: []PriorityBandConfig{{ + Priority: pHigh, + PriorityName: "High", + IntraFlowDispatchPolicy: fcfs.FCFSPolicyName, + InterFlowDispatchPolicy: besthead.BestHeadPolicyName, + Queue: listqueue.ListQueueName, + }}, + } + require.NoError(t, baseConfig.validateAndApplyDefaults(), "Test setup: base config should be valid") + + t.Run("WhenInterFlowPolicyIsInvalid_ShouldFail", func(t *testing.T) { + failingPolicyName := inter.RegisteredPolicyName(fmt.Sprintf("failing-inter-policy-%s", t.Name())) + inter.MustRegisterPolicy(failingPolicyName, func() (framework.InterFlowDispatchPolicy, error) { + return nil, errors.New("inter-flow instantiation failed") + }) + + badConfig := baseConfig.deepCopy() + badConfig.PriorityBands[0].InterFlowDispatchPolicy = failingPolicyName + + _, err := newShard("test", badConfig, logr.Discard(), shardCallbacks{}) + require.Error(t, err, "newShard should fail with an invalid inter-flow policy") }) } diff --git a/pkg/epp/flowcontrol/types/errors.go b/pkg/epp/flowcontrol/types/errors.go index 4da47c179..f7dffbd4d 100644 --- a/pkg/epp/flowcontrol/types/errors.go +++ b/pkg/epp/flowcontrol/types/errors.go @@ -46,9 +46,6 @@ var ( // ErrNilRequest indicates that a nil `types.FlowControlRequest` was provided. ErrNilRequest = errors.New("FlowControlRequest cannot be nil") - // ErrFlowIDEmpty indicates that a flow ID was empty when one was required. - ErrFlowIDEmpty = errors.New("flow ID cannot be empty") - // ErrQueueAtCapacity indicates that a request could not be enqueued because queue capacity limits were met. ErrQueueAtCapacity = errors.New("queue at capacity and displacement failed to make space") )
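To round out the failure-injection pattern exercised by `TestShard_New_ErrorPaths` above, here is a self-contained sketch of the general technique: register a factory that always fails and assert that the constructor wraps and surfaces that error. The registry, factory type, and function names below are hypothetical and are not the plugin APIs of this repository.

package main

import (
	"errors"
	"fmt"
)

// A toy plugin registry: names map to factories that may fail at instantiation time.
type policyFactory func() (any, error)

var registry = map[string]policyFactory{}

func mustRegister(name string, f policyFactory) { registry[name] = f }

// newComponent mirrors the constructor pattern under test: it resolves the
// configured factory and wraps any instantiation error with context.
func newComponent(policyName string) (any, error) {
	f, ok := registry[policyName]
	if !ok {
		return nil, fmt.Errorf("policy %q not registered", policyName)
	}
	p, err := f()
	if err != nil {
		return nil, fmt.Errorf("failed to create policy %q: %w", policyName, err)
	}
	return p, nil
}

func main() {
	// Failure injection: the factory itself returns an error.
	mustRegister("failing-policy", func() (any, error) {
		return nil, errors.New("instantiation failed")
	})

	if _, err := newComponent("failing-policy"); err != nil {
		fmt.Println("constructor surfaced the error:", err)
	}
}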