Skip to content

Commit e95d8e8

Browse files
authored
Merge shuffle score pods logic (#1552)
1 parent bd4bf22 commit e95d8e8

File tree

3 files changed

+20
-22
lines changed

3 files changed

+20
-22
lines changed

pkg/epp/scheduling/framework/plugins/picker/common.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@ limitations under the License.
1616

1717
package picker
1818

19+
import (
20+
"math/rand/v2"
21+
"time"
22+
23+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
24+
)
25+
1926
const (
2027
DefaultMaxNumOfEndpoints = 1 // common default to all pickers
2128
)
@@ -24,3 +31,14 @@ const (
2431
type pickerParameters struct {
2532
MaxNumOfEndpoints int `json:"maxNumOfEndpoints"`
2633
}
34+
35+
func shuffleScoredPods(scoredPods []*types.ScoredPod) {
36+
// Rand package is not safe for concurrent use, so we create a new instance.
37+
// Source: https://pkg.go.dev/math/rand/v2#pkg-overview
38+
randomGenerator := rand.New(rand.NewPCG(uint64(time.Now().UnixNano()), 0))
39+
40+
// Shuffle in-place
41+
randomGenerator.Shuffle(len(scoredPods), func(i, j int) {
42+
scoredPods[i], scoredPods[j] = scoredPods[j], scoredPods[i]
43+
})
44+
}

pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,9 @@ import (
2020
"context"
2121
"encoding/json"
2222
"fmt"
23-
"math/rand"
2423
"slices"
25-
"time"
2624

2725
"sigs.k8s.io/controller-runtime/pkg/log"
28-
2926
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
3027
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
3128
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
@@ -85,15 +82,8 @@ func (p *MaxScorePicker) Pick(ctx context.Context, cycleState *types.CycleState,
8582
log.FromContext(ctx).V(logutil.DEBUG).Info("Selecting pods from candidates sorted by max score", "max-num-of-endpoints", p.maxNumOfEndpoints,
8683
"num-of-candidates", len(scoredPods), "scored-pods", scoredPods)
8784

88-
// TODO: merge this with the logic in RandomPicker
89-
// Rand package is not safe for concurrent use, so we create a new instance.
90-
// Source: https://pkg.go.dev/math/rand#pkg-overview
91-
randomGenerator := rand.New(rand.NewSource(time.Now().UnixNano()))
92-
9385
// Shuffle in-place - needed for random tie break when scores are equal
94-
randomGenerator.Shuffle(len(scoredPods), func(i, j int) {
95-
scoredPods[i], scoredPods[j] = scoredPods[j], scoredPods[i]
96-
})
86+
shuffleScoredPods(scoredPods)
9787

9888
slices.SortStableFunc(scoredPods, func(i, j *types.ScoredPod) int { // highest score first
9989
if i.Score > j.Score {

pkg/epp/scheduling/framework/plugins/picker/random_picker.go

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,8 @@ import (
2020
"context"
2121
"encoding/json"
2222
"fmt"
23-
"math/rand"
24-
"time"
2523

2624
"sigs.k8s.io/controller-runtime/pkg/log"
27-
2825
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
2926
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
3027
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
@@ -84,15 +81,8 @@ func (p *RandomPicker) Pick(ctx context.Context, _ *types.CycleState, scoredPods
8481
log.FromContext(ctx).V(logutil.DEBUG).Info("Selecting pods from candidates randomly", "max-num-of-endpoints", p.maxNumOfEndpoints,
8582
"num-of-candidates", len(scoredPods), "scored-pods", scoredPods)
8683

87-
// TODO: merge this with the logic in MaxScorePicker
88-
// Rand package is not safe for concurrent use, so we create a new instance.
89-
// Source: https://pkg.go.dev/math/rand#pkg-overview
90-
randomGenerator := rand.New(rand.NewSource(time.Now().UnixNano()))
91-
9284
// Shuffle in-place
93-
randomGenerator.Shuffle(len(scoredPods), func(i, j int) {
94-
scoredPods[i], scoredPods[j] = scoredPods[j], scoredPods[i]
95-
})
85+
shuffleScoredPods(scoredPods)
9686

9787
// if we have enough pods to return keep only the relevant subset
9888
if p.maxNumOfEndpoints < len(scoredPods) {

0 commit comments

Comments
 (0)