diff --git a/pkg/epp/handlers/request_test.go b/pkg/epp/handlers/request_test.go index 4ae207803..ea317ca00 100644 --- a/pkg/epp/handlers/request_test.go +++ b/pkg/epp/handlers/request_test.go @@ -21,6 +21,7 @@ import ( configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata" ) diff --git a/pkg/epp/scheduling/framework/plugins/picker/common.go b/pkg/epp/scheduling/framework/plugins/picker/common.go index 4bbc300da..bf1d1b07f 100644 --- a/pkg/epp/scheduling/framework/plugins/picker/common.go +++ b/pkg/epp/scheduling/framework/plugins/picker/common.go @@ -16,6 +16,12 @@ limitations under the License. package picker +import ( + "math/rand/v2" + "sync" + "time" +) + const ( DefaultMaxNumOfEndpoints = 1 // common default to all pickers ) @@ -24,3 +30,29 @@ const ( type pickerParameters struct { MaxNumOfEndpoints int `json:"maxNumOfEndpoints"` } + +// safeRand is a thread-safe wrapper around rand.Rand +// to ensure that random operations are safe to use in concurrent environments. +type safeRand struct { + p *sync.Pool +} + +// NewSafeRand initializes a new safeRand. +func NewSafeRand() *safeRand { + p := &sync.Pool{ + New: func() any { + seed := time.Now().UnixNano() + return rand.New(rand.NewPCG(uint64(seed), uint64(seed))) + }, + } + + return &safeRand{p: p} +} + +// Shuffle is a thread-safe method to shuffle a slice. 
+func (s *safeRand) Shuffle(n int, swap func(i int, j int)) { + r := s.p.Get().(*rand.Rand) + defer s.p.Put(r) + + r.Shuffle(n, swap) +} diff --git a/pkg/epp/scheduling/framework/plugins/picker/common_test.go b/pkg/epp/scheduling/framework/plugins/picker/common_test.go new file mode 100644 index 000000000..e5e03bfe4 --- /dev/null +++ b/pkg/epp/scheduling/framework/plugins/picker/common_test.go @@ -0,0 +1,253 @@ +package picker + +import ( + "context" + "fmt" + "testing" + "time" +) + +func TestNewSafeRand_NotNil(t *testing.T) { + sr := NewSafeRand() + if sr == nil { + t.Fatal("NewSafeRand returned nil") + } + if sr.p == nil { + t.Fatal("sync.Pool in safeRand is nil") + } +} + +func TestSafeRand_Shuffle_ShufflesSlice(t *testing.T) { + sr := NewSafeRand() + orig := []int{1, 2, 3, 4, 5} + shuffled := make([]int, len(orig)) + copy(shuffled, orig) + + sr.Shuffle(len(shuffled), func(i, j int) { + shuffled[i], shuffled[j] = shuffled[j], shuffled[i] + }) + + // It's possible (but unlikely) that the slice remains unchanged after shuffling. + // So we check that at least one element has moved. 
+ same := true + for i := range orig { + if orig[i] != shuffled[i] { + same = false + break + } + } + if same { + t.Log("Warning: shuffled slice is the same as original (possible but unlikely)") + } +} + +func TestSafeRand_ConcurrentShuffle(t *testing.T) { + sr := NewSafeRand() + const goroutines = 10 + const sliceLen = 100 + done := make(chan struct{}, goroutines) + + for g := 0; g < goroutines; g++ { + go func() { + s := make([]int, sliceLen) + for i := range s { + s[i] = i + } + sr.Shuffle(len(s), func(i, j int) { + s[i], s[j] = s[j], s[i] + }) + done <- struct{}{} + }() + } + + for g := 0; g < goroutines; g++ { + <-done + } +} + +func TestSafeRand_Shuffle_EmptySlice(t *testing.T) { + sr := NewSafeRand() + emptySlice := []int{} + + // Should not panic with empty slice + sr.Shuffle(len(emptySlice), func(i, j int) { + emptySlice[i], emptySlice[j] = emptySlice[j], emptySlice[i] + }) + + if len(emptySlice) != 0 { + t.Errorf("Expected empty slice to remain empty, got length %d", len(emptySlice)) + } +} + +func TestSafeRand_Shuffle_SingleElement(t *testing.T) { + sr := NewSafeRand() + singleElement := []int{42} + + sr.Shuffle(len(singleElement), func(i, j int) { + singleElement[i], singleElement[j] = singleElement[j], singleElement[i] + }) + + if len(singleElement) != 1 || singleElement[0] != 42 { + t.Errorf("Expected single element slice to remain [42], got %v", singleElement) + } +} + +// Benchmark tests for safeRand. 
+func BenchmarkSafeRand_Shuffle_Small(b *testing.B) { + sr := NewSafeRand() + slice := make([]int, 10) + for i := range slice { + slice[i] = i + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + sr.Shuffle(len(slice), func(i, j int) { + slice[i], slice[j] = slice[j], slice[i] + }) + } +} + +func BenchmarkSafeRand_Shuffle_Large(b *testing.B) { + sr := NewSafeRand() + slice := make([]int, 1000) + for i := range slice { + slice[i] = i + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + sr.Shuffle(len(slice), func(i, j int) { + slice[i], slice[j] = slice[j], slice[i] + }) + } +} + +func BenchmarkSafeRand_Shuffle_Concurrent(b *testing.B) { + sr := NewSafeRand() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + slice := make([]int, 100) + for i := range slice { + slice[i] = i + } + + for pb.Next() { + sr.Shuffle(len(slice), func(i, j int) { + slice[i], slice[j] = slice[j], slice[i] + }) + } + }) +} + +// Advanced concurrent tests for safeRand +func TestSafeRand_HighConcurrency(t *testing.T) { + sr := NewSafeRand() + const numGoroutines = 100 + const numOperations = 100 + const sliceSize = 50 + + done := make(chan struct{}, numGoroutines) + errors := make(chan error, numGoroutines) + + for g := 0; g < numGoroutines; g++ { + go func(goroutineID int) { + defer func() { + if r := recover(); r != nil { + errors <- fmt.Errorf("goroutine %d panicked: %v", goroutineID, r) + return + } + done <- struct{}{} + }() + + for op := 0; op < numOperations; op++ { + slice := make([]int, sliceSize) + for i := range slice { + slice[i] = i + } + + sr.Shuffle(len(slice), func(i, j int) { + slice[i], slice[j] = slice[j], slice[i] + }) + + // Verify all elements are still present + for expected := 0; expected < sliceSize; expected++ { + found := false + for _, actual := range slice { + if actual == expected { + found = true + break + } + } + if !found { + errors <- fmt.Errorf("goroutine %d: element %d missing after shuffle", goroutineID, expected) + return + } + } + } + }(g) + } + 
+ // Wait for all goroutines to complete + for g := 0; g < numGoroutines; g++ { + select { + case <-done: + // Goroutine completed successfully + case err := <-errors: + t.Fatal(err) + } + } +} + +func TestSafeRand_StressTest(t *testing.T) { + sr := NewSafeRand() + const duration = 100 * time.Millisecond + const numGoroutines = 50 + + ctx, cancel := context.WithTimeout(context.Background(), duration) + defer cancel() + + done := make(chan struct{}, numGoroutines) + errors := make(chan error, numGoroutines) + + for g := 0; g < numGoroutines; g++ { + go func(goroutineID int) { + defer func() { + if r := recover(); r != nil { + errors <- fmt.Errorf("goroutine %d panicked: %v", goroutineID, r) + return + } + done <- struct{}{} + }() + + for { + select { + case <-ctx.Done(): + return + default: + slice := make([]int, 20) + for i := range slice { + slice[i] = i + } + sr.Shuffle(len(slice), func(i, j int) { + slice[i], slice[j] = slice[j], slice[i] + }) + } + } + }(g) + } + + // Wait for context timeout + <-ctx.Done() + + // Wait for all goroutines to complete + for g := 0; g < numGoroutines; g++ { + select { + case <-done: + // Goroutine completed successfully + case err := <-errors: + t.Fatal(err) + } + } +} diff --git a/pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go b/pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go index a0e973fe6..d346792da 100644 --- a/pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go +++ b/pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go @@ -20,9 +20,7 @@ import ( "context" "encoding/json" "fmt" - "math/rand" "slices" - "time" "sigs.k8s.io/controller-runtime/pkg/log" @@ -60,13 +58,15 @@ func NewMaxScorePicker(maxNumOfEndpoints int) *MaxScorePicker { return &MaxScorePicker{ typedName: plugins.TypedName{Type: MaxScorePickerType, Name: MaxScorePickerType}, maxNumOfEndpoints: maxNumOfEndpoints, + rand: NewSafeRand(), // init with default safe random source. 
} } // MaxScorePicker picks pod(s) with the maximum score from the list of candidates. type MaxScorePicker struct { typedName plugins.TypedName - maxNumOfEndpoints int // maximum number of endpoints to pick + maxNumOfEndpoints int // maximum number of endpoints to pick + rand *safeRand // thread-safe random number generator } // WithName sets the picker's name @@ -85,13 +85,8 @@ func (p *MaxScorePicker) Pick(ctx context.Context, cycleState *types.CycleState, log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting maximum '%d' pods from %d candidates sorted by max score: %+v", p.maxNumOfEndpoints, len(scoredPods), scoredPods)) - // TODO: merge this with the logic in RandomPicker - // Rand package is not safe for concurrent use, so we create a new instance. - // Source: https://pkg.go.dev/math/rand#pkg-overview - randomGenerator := rand.New(rand.NewSource(time.Now().UnixNano())) - // Shuffle in-place - needed for random tie break when scores are equal - randomGenerator.Shuffle(len(scoredPods), func(i, j int) { + p.rand.Shuffle(len(scoredPods), func(i, j int) { scoredPods[i], scoredPods[j] = scoredPods[j], scoredPods[i] }) diff --git a/pkg/epp/scheduling/framework/plugins/picker/random_picker.go b/pkg/epp/scheduling/framework/plugins/picker/random_picker.go index b571ce3a5..7f38b3e23 100644 --- a/pkg/epp/scheduling/framework/plugins/picker/random_picker.go +++ b/pkg/epp/scheduling/framework/plugins/picker/random_picker.go @@ -20,8 +20,6 @@ import ( "context" "encoding/json" "fmt" - "math/rand" - "time" "sigs.k8s.io/controller-runtime/pkg/log" @@ -58,6 +56,7 @@ func NewRandomPicker(maxNumOfEndpoints int) *RandomPicker { return &RandomPicker{ typedName: plugins.TypedName{Type: RandomPickerType, Name: RandomPickerType}, maxNumOfEndpoints: maxNumOfEndpoints, + rand: NewSafeRand(), // init with default safe random source. 
} } @@ -65,6 +64,7 @@ func NewRandomPicker(maxNumOfEndpoints int) *RandomPicker { type RandomPicker struct { typedName plugins.TypedName maxNumOfEndpoints int + rand *safeRand // thread-safe random number generator } // WithName sets the name of the picker. @@ -83,13 +83,8 @@ func (p *RandomPicker) Pick(ctx context.Context, _ *types.CycleState, scoredPods log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting maximum '%d' pods from %d candidates randomly: %+v", p.maxNumOfEndpoints, len(scoredPods), scoredPods)) - // TODO: merge this with the logic in MaxScorePicker - // Rand package is not safe for concurrent use, so we create a new instance. - // Source: https://pkg.go.dev/math/rand#pkg-overview - randomGenerator := rand.New(rand.NewSource(time.Now().UnixNano())) - - // Shuffle in-place - randomGenerator.Shuffle(len(scoredPods), func(i, j int) { + // Shuffle in-place to randomize the order of the candidates + p.rand.Shuffle(len(scoredPods), func(i, j int) { scoredPods[i], scoredPods[j] = scoredPods[j], scoredPods[i] }) diff --git a/pkg/epp/util/testing/wrappers.go b/pkg/epp/util/testing/wrappers.go index cab70a445..4079c141a 100644 --- a/pkg/epp/util/testing/wrappers.go +++ b/pkg/epp/util/testing/wrappers.go @@ -19,6 +19,7 @@ package testing import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" )