Skip to content

[Optimization]: 🔨 add thread-safe wrapper around rand.Shuffle and improve performance. #1335

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/epp/handlers/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata"
)

Expand Down
32 changes: 32 additions & 0 deletions pkg/epp/scheduling/framework/plugins/picker/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ limitations under the License.

package picker

import (
"math/rand/v2"
"sync"
"time"
)

const (
DefaultMaxNumOfEndpoints = 1 // common default to all pickers
)
Expand All @@ -24,3 +30,29 @@ const (
type pickerParameters struct {
MaxNumOfEndpoints int `json:"maxNumOfEndpoints"`
}

// safeRand is a thread-safe wrapper around rand.Rand
// to ensure that random operations are safe to use in concurrent environments.
type safeRand struct {
p *sync.Pool
}

// NewSafeRand initializes a new safeRand.
func NewSafeRand() *safeRand {
p := &sync.Pool{
New: func() any {
seed := time.Now().UnixNano()
return rand.New(rand.NewPCG(uint64(seed), uint64(seed)))
},
}

return &safeRand{p: p}
}

// Shuffle is a thread-safe method to shuffle a slice.
func (s *safeRand) Shuffle(n int, swap func(i int, j int)) {
r := s.p.Get().(*rand.Rand)
defer s.p.Put(r)

r.Shuffle(n, swap)
}
253 changes: 253 additions & 0 deletions pkg/epp/scheduling/framework/plugins/picker/common_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
package picker

import (
"context"
"fmt"
"testing"
"time"
)

func TestNewSafeRand_NotNil(t *testing.T) {
sr := NewSafeRand()
if sr == nil {
t.Fatal("NewSafeRand returned nil")
}
if sr.p == nil {
t.Fatal("sync.Pool in safeRand is nil")
}
}

func TestSafeRand_Shuffle_ShufflesSlice(t *testing.T) {
sr := NewSafeRand()
orig := []int{1, 2, 3, 4, 5}
shuffled := make([]int, len(orig))
copy(shuffled, orig)

sr.Shuffle(len(shuffled), func(i, j int) {
shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
})

// It's possible (but unlikely) that the slice remains unchanged after shuffling.
// So we check that at least one element has moved.
same := true
for i := range orig {
if orig[i] != shuffled[i] {
same = false
break
}
}
if same {
t.Log("Warning: shuffled slice is the same as original (possible but unlikely)")
}
}

func TestSafeRand_ConcurrentShuffle(t *testing.T) {
sr := NewSafeRand()
const goroutines = 10
const sliceLen = 100
done := make(chan struct{}, goroutines)

for g := 0; g < goroutines; g++ {
go func() {
s := make([]int, sliceLen)
for i := range s {
s[i] = i
}
sr.Shuffle(len(s), func(i, j int) {
s[i], s[j] = s[j], s[i]
})
done <- struct{}{}
}()
}

for g := 0; g < goroutines; g++ {
<-done
}
}

func TestSafeRand_Shuffle_EmptySlice(t *testing.T) {
sr := NewSafeRand()
emptySlice := []int{}

// Should not panic with empty slice
sr.Shuffle(len(emptySlice), func(i, j int) {
emptySlice[i], emptySlice[j] = emptySlice[j], emptySlice[i]
})

if len(emptySlice) != 0 {
t.Errorf("Expected empty slice to remain empty, got length %d", len(emptySlice))
}
}

func TestSafeRand_Shuffle_SingleElement(t *testing.T) {
sr := NewSafeRand()
singleElement := []int{42}

sr.Shuffle(len(singleElement), func(i, j int) {
singleElement[i], singleElement[j] = singleElement[j], singleElement[i]
})

if len(singleElement) != 1 || singleElement[0] != 42 {
t.Errorf("Expected single element slice to remain [42], got %v", singleElement)
}
}

// Benchmark tests for safeRand.
func BenchmarkSafeRand_Shuffle_Small(b *testing.B) {
sr := NewSafeRand()
slice := make([]int, 10)
for i := range slice {
slice[i] = i
}

b.ResetTimer()
for i := 0; i < b.N; i++ {
sr.Shuffle(len(slice), func(i, j int) {
slice[i], slice[j] = slice[j], slice[i]
})
}
}

func BenchmarkSafeRand_Shuffle_Large(b *testing.B) {
sr := NewSafeRand()
slice := make([]int, 1000)
for i := range slice {
slice[i] = i
}

b.ResetTimer()
for i := 0; i < b.N; i++ {
sr.Shuffle(len(slice), func(i, j int) {
slice[i], slice[j] = slice[j], slice[i]
})
}
}

func BenchmarkSafeRand_Shuffle_Concurrent(b *testing.B) {
sr := NewSafeRand()

b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
slice := make([]int, 100)
for i := range slice {
slice[i] = i
}

for pb.Next() {
sr.Shuffle(len(slice), func(i, j int) {
slice[i], slice[j] = slice[j], slice[i]
})
}
})
}

// Advanced concurrent tests for safeRand
func TestSafeRand_HighConcurrency(t *testing.T) {
sr := NewSafeRand()
const numGoroutines = 100
const numOperations = 100
const sliceSize = 50

done := make(chan struct{}, numGoroutines)
errors := make(chan error, numGoroutines)

for g := 0; g < numGoroutines; g++ {
go func(goroutineID int) {
defer func() {
if r := recover(); r != nil {
errors <- fmt.Errorf("goroutine %d panicked: %v", goroutineID, r)
return
}
done <- struct{}{}
}()

for op := 0; op < numOperations; op++ {
slice := make([]int, sliceSize)
for i := range slice {
slice[i] = i
}

sr.Shuffle(len(slice), func(i, j int) {
slice[i], slice[j] = slice[j], slice[i]
})

// Verify all elements are still present
for expected := 0; expected < sliceSize; expected++ {
found := false
for _, actual := range slice {
if actual == expected {
found = true
break
}
}
if !found {
errors <- fmt.Errorf("goroutine %d: element %d missing after shuffle", goroutineID, expected)
return
}
}
}
}(g)
}

// Wait for all goroutines to complete
for g := 0; g < numGoroutines; g++ {
select {
case <-done:
// Goroutine completed successfully
case err := <-errors:
t.Fatal(err)
}
}
}

func TestSafeRand_StressTest(t *testing.T) {
sr := NewSafeRand()
const duration = 100 * time.Millisecond
const numGoroutines = 50

ctx, cancel := context.WithTimeout(context.Background(), duration)
defer cancel()

done := make(chan struct{}, numGoroutines)
errors := make(chan error, numGoroutines)

for g := 0; g < numGoroutines; g++ {
go func(goroutineID int) {
defer func() {
if r := recover(); r != nil {
errors <- fmt.Errorf("goroutine %d panicked: %v", goroutineID, r)
return
}
done <- struct{}{}
}()

for {
select {
case <-ctx.Done():
return
default:
slice := make([]int, 20)
for i := range slice {
slice[i] = i
}
sr.Shuffle(len(slice), func(i, j int) {
slice[i], slice[j] = slice[j], slice[i]
})
}
}
}(g)
}

// Wait for context timeout
<-ctx.Done()

// Wait for all goroutines to complete
for g := 0; g < numGoroutines; g++ {
select {
case <-done:
// Goroutine completed successfully
case err := <-errors:
t.Fatal(err)
}
}
}
13 changes: 4 additions & 9 deletions pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ import (
"context"
"encoding/json"
"fmt"
"math/rand"
"slices"
"time"

"sigs.k8s.io/controller-runtime/pkg/log"

Expand Down Expand Up @@ -60,13 +58,15 @@ func NewMaxScorePicker(maxNumOfEndpoints int) *MaxScorePicker {
return &MaxScorePicker{
typedName: plugins.TypedName{Type: MaxScorePickerType, Name: MaxScorePickerType},
maxNumOfEndpoints: maxNumOfEndpoints,
rand: NewSafeRand(), // init with default safe random source.
}
}

// MaxScorePicker picks pod(s) with the maximum score from the list of candidates.
type MaxScorePicker struct {
typedName plugins.TypedName
maxNumOfEndpoints int // maximum number of endpoints to pick
maxNumOfEndpoints int // maximum number of endpoints to pick
rand *safeRand // thread-safe random number generator
}

// WithName sets the picker's name
Expand All @@ -85,13 +85,8 @@ func (p *MaxScorePicker) Pick(ctx context.Context, cycleState *types.CycleState,
log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting maximum '%d' pods from %d candidates sorted by max score: %+v", p.maxNumOfEndpoints,
len(scoredPods), scoredPods))

// TODO: merge this with the logic in RandomPicker
// Rand package is not safe for concurrent use, so we create a new instance.
// Source: https://pkg.go.dev/math/rand#pkg-overview
randomGenerator := rand.New(rand.NewSource(time.Now().UnixNano()))

// Shuffle in-place - needed for random tie break when scores are equal
randomGenerator.Shuffle(len(scoredPods), func(i, j int) {
p.rand.Shuffle(len(scoredPods), func(i, j int) {
scoredPods[i], scoredPods[j] = scoredPods[j], scoredPods[i]
})

Expand Down
13 changes: 4 additions & 9 deletions pkg/epp/scheduling/framework/plugins/picker/random_picker.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ import (
"context"
"encoding/json"
"fmt"
"math/rand"
"time"

"sigs.k8s.io/controller-runtime/pkg/log"

Expand Down Expand Up @@ -58,13 +56,15 @@ func NewRandomPicker(maxNumOfEndpoints int) *RandomPicker {
return &RandomPicker{
typedName: plugins.TypedName{Type: RandomPickerType, Name: RandomPickerType},
maxNumOfEndpoints: maxNumOfEndpoints,
rand: NewSafeRand(), // init with default safe random source.
}
}

// RandomPicker picks random pod(s) from the list of candidates.
type RandomPicker struct {
typedName plugins.TypedName
maxNumOfEndpoints int
rand *safeRand // thread-safe random number generator
}

// WithName sets the name of the picker.
Expand All @@ -83,13 +83,8 @@ func (p *RandomPicker) Pick(ctx context.Context, _ *types.CycleState, scoredPods
log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting maximum '%d' pods from %d candidates randomly: %+v", p.maxNumOfEndpoints,
len(scoredPods), scoredPods))

// TODO: merge this with the logic in MaxScorePicker
// Rand package is not safe for concurrent use, so we create a new instance.
// Source: https://pkg.go.dev/math/rand#pkg-overview
randomGenerator := rand.New(rand.NewSource(time.Now().UnixNano()))

// Shuffle in-place
randomGenerator.Shuffle(len(scoredPods), func(i, j int) {
// Shuffle in-place - needed for random tie break when scores are equal
p.rand.Shuffle(len(scoredPods), func(i, j int) {
scoredPods[i], scoredPods[j] = scoredPods[j], scoredPods[i]
})

Expand Down
1 change: 1 addition & 0 deletions pkg/epp/util/testing/wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package testing
import (
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1"
"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
)
Expand Down