Skip to content

Commit c8e11a8

Browse files
Joohonirrozenbaum
authored andcommitted
Add WeightedRandomPicker (#1412)
* Add WeightedRandomPicker with score normalization Implements weighted random sampling picker with 4 normalization options: - None (default): Pure weighted random sampling - Square Root: Moderate score difference reduction - Capping: Limits max score ratio to prevent hot-spotting - Logarithmic: Maximum load distribution for extreme score variations Addresses hot-spotting issues in max-score-picker while maintaining score-based preferences through configurable normalization strategies. Signed-off-by: Jooho Lee <[email protected]> * remove normalization and add topN logic Signed-off-by: Jooho Lee <[email protected]> * removed all the optimization parts and implemented only the core WRS logic Signed-off-by: Jooho Lee <[email protected]> * Change to delegate selection to RandomPicker for uniform picking when all weights are zero Signed-off-by: Jooho Lee <[email protected]> * follw up comments Signed-off-by: Jooho Lee <[email protected]> --------- Signed-off-by: Jooho Lee <[email protected]>
1 parent a1a4b8a commit c8e11a8

File tree

4 files changed

+288
-0
lines changed

4 files changed

+288
-0
lines changed

cmd/epp/runner/runner.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ func (r *Runner) registerInTreePlugins() {
301301
plugins.Register(prefix.PrefixCachePluginType, prefix.PrefixCachePluginFactory)
302302
plugins.Register(picker.MaxScorePickerType, picker.MaxScorePickerFactory)
303303
plugins.Register(picker.RandomPickerType, picker.RandomPickerFactory)
304+
plugins.Register(picker.WeightedRandomPickerType, picker.WeightedRandomPickerFactory)
304305
plugins.Register(profile.SingleProfileHandlerType, profile.SingleProfileHandlerFactory)
305306
plugins.Register(scorer.KvCacheUtilizationScorerType, scorer.KvCacheUtilizationScorerFactory)
306307
plugins.Register(scorer.QueueScorerType, scorer.QueueScorerFactory)

pkg/epp/config/loader/configloader_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,7 @@ func registerNeededPlgugins() {
443443
plugins.Register(prefix.PrefixCachePluginType, prefix.PrefixCachePluginFactory)
444444
plugins.Register(picker.MaxScorePickerType, picker.MaxScorePickerFactory)
445445
plugins.Register(picker.RandomPickerType, picker.RandomPickerFactory)
446+
plugins.Register(picker.WeightedRandomPickerType, picker.WeightedRandomPickerFactory)
446447
plugins.Register(profile.SingleProfileHandlerType, profile.SingleProfileHandlerFactory)
447448
}
448449

pkg/epp/scheduling/framework/plugins/picker/picker_test.go

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,120 @@ func TestPickMaxScorePicker(t *testing.T) {
135135
})
136136
}
137137
}
138+
139+
func TestPickWeightedRandomPicker(t *testing.T) {
140+
const (
141+
testIterations = 1000
142+
tolerance = 0.2 // 20% tolerance in [0,1] range
143+
)
144+
145+
pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}
146+
pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}
147+
pod3 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}
148+
pod4 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod4"}}}
149+
pod5 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod5"}}}
150+
151+
// A-Res algorithm uses U^(1/w) transformation which introduces statistical variance
152+
// beyond simple proportional sampling. Generous tolerance is required to prevent
153+
// flaky tests in CI environments, especially for multi-tier weights.
154+
tests := []struct {
155+
name string
156+
input []*types.ScoredPod
157+
maxPods int // maxNumOfEndpoints for this test
158+
}{
159+
{
160+
name: "High weight dominance test",
161+
input: []*types.ScoredPod{
162+
{Pod: pod1, Score: 10}, // Lower weight
163+
{Pod: pod2, Score: 90}, // Higher weight (should dominate)
164+
},
165+
maxPods: 1,
166+
},
167+
{
168+
name: "Equal weights test - A-Res uniform distribution",
169+
input: []*types.ScoredPod{
170+
{Pod: pod1, Score: 100}, // Equal weights (higher values for better numerical precision)
171+
{Pod: pod2, Score: 100}, // Equal weights should yield uniform distribution
172+
{Pod: pod3, Score: 100}, // Equal weights in A-Res
173+
},
174+
maxPods: 1,
175+
},
176+
{
177+
name: "Zero weight exclusion test - A-Res edge case",
178+
input: []*types.ScoredPod{
179+
{Pod: pod1, Score: 30}, // Normal weight, should be selected
180+
{Pod: pod2, Score: 0}, // Zero weight, never selected in A-Res
181+
},
182+
maxPods: 1,
183+
},
184+
{
185+
name: "Multi-tier weighted test - A-Res complex distribution",
186+
input: []*types.ScoredPod{
187+
{Pod: pod1, Score: 100}, // Highest weight
188+
{Pod: pod2, Score: 90}, // High weight
189+
{Pod: pod3, Score: 50}, // Medium weight
190+
{Pod: pod4, Score: 30}, // Low weight
191+
{Pod: pod5, Score: 20}, // Lowest weight
192+
},
193+
maxPods: 1,
194+
},
195+
}
196+
197+
for _, test := range tests {
198+
t.Run(test.name, func(t *testing.T) {
199+
picker := NewWeightedRandomPicker(test.maxPods)
200+
selectionCounts := make(map[string]int)
201+
202+
// Calculate expected probabilities based on scores
203+
totalScore := 0.0
204+
for _, pod := range test.input {
205+
totalScore += pod.Score
206+
}
207+
208+
expectedProbabilities := make(map[string]float64)
209+
for _, pod := range test.input {
210+
podName := pod.GetPod().NamespacedName.Name
211+
if totalScore > 0 {
212+
expectedProbabilities[podName] = pod.Score / totalScore
213+
} else {
214+
expectedProbabilities[podName] = 0.0
215+
}
216+
}
217+
218+
// Initialize selection counters for each pod
219+
for _, pod := range test.input {
220+
podName := pod.GetPod().NamespacedName.Name
221+
selectionCounts[podName] = 0
222+
}
223+
224+
// Run multiple iterations to gather statistical data
225+
for i := 0; i < testIterations; i++ {
226+
result := picker.Pick(context.Background(), types.NewCycleState(), test.input)
227+
228+
// Count selections for probability analysis
229+
if len(result.TargetPods) > 0 {
230+
selectedPodName := result.TargetPods[0].GetPod().NamespacedName.Name
231+
selectionCounts[selectedPodName]++
232+
}
233+
}
234+
235+
// Verify probability distribution
236+
for podName, expectedProb := range expectedProbabilities {
237+
actualCount := selectionCounts[podName]
238+
actualProb := float64(actualCount) / float64(testIterations)
239+
240+
toleranceValue := expectedProb * tolerance
241+
lowerBound := expectedProb - toleranceValue
242+
upperBound := expectedProb + toleranceValue
243+
244+
if actualProb < lowerBound || actualProb > upperBound {
245+
t.Errorf("Pod %s: expected probability %.3f ±%.1f%%, got %.3f (count: %d/%d)",
246+
podName, expectedProb, tolerance*100, actualProb, actualCount, testIterations)
247+
} else {
248+
t.Logf("Pod %s: expected %.3f, got %.3f (count: %d/%d) ✓",
249+
podName, expectedProb, actualProb, actualCount, testIterations)
250+
}
251+
}
252+
})
253+
}
254+
}
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package picker
18+
19+
import (
20+
"context"
21+
"encoding/json"
22+
"fmt"
23+
"math"
24+
"math/rand"
25+
"sort"
26+
"time"
27+
28+
"sigs.k8s.io/controller-runtime/pkg/log"
29+
30+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
31+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
32+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
33+
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
34+
)
35+
36+
const (
37+
WeightedRandomPickerType = "weighted-random-picker"
38+
)
39+
40+
// weightedScoredPod represents a scored pod with its A-Res sampling key
41+
type weightedScoredPod struct {
42+
*types.ScoredPod
43+
key float64
44+
}
45+
46+
var _ framework.Picker = &WeightedRandomPicker{}
47+
48+
func WeightedRandomPickerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
49+
parameters := pickerParameters{
50+
MaxNumOfEndpoints: DefaultMaxNumOfEndpoints,
51+
}
52+
if rawParameters != nil {
53+
if err := json.Unmarshal(rawParameters, &parameters); err != nil {
54+
return nil, fmt.Errorf("failed to parse the parameters of the '%s' picker - %w", WeightedRandomPickerType, err)
55+
}
56+
}
57+
58+
return NewWeightedRandomPicker(parameters.MaxNumOfEndpoints).WithName(name), nil
59+
}
60+
61+
func NewWeightedRandomPicker(maxNumOfEndpoints int) *WeightedRandomPicker {
62+
if maxNumOfEndpoints <= 0 {
63+
maxNumOfEndpoints = DefaultMaxNumOfEndpoints
64+
}
65+
66+
return &WeightedRandomPicker{
67+
typedName: plugins.TypedName{Type: WeightedRandomPickerType, Name: WeightedRandomPickerType},
68+
maxNumOfEndpoints: maxNumOfEndpoints,
69+
randomPicker: NewRandomPicker(maxNumOfEndpoints),
70+
}
71+
}
72+
73+
type WeightedRandomPicker struct {
74+
typedName plugins.TypedName
75+
maxNumOfEndpoints int
76+
randomPicker *RandomPicker // fallback for zero weights
77+
}
78+
79+
func (p *WeightedRandomPicker) WithName(name string) *WeightedRandomPicker {
80+
p.typedName.Name = name
81+
return p
82+
}
83+
84+
func (p *WeightedRandomPicker) TypedName() plugins.TypedName {
85+
return p.typedName
86+
}
87+
88+
// WeightedRandomPicker performs weighted random sampling using A-Res algorithm.
89+
// Reference: https://utopia.duth.gr/~pefraimi/research/data/2007EncOfAlg.pdf
90+
// Algorithm:
91+
// - Uses A-Res (Algorithm for Reservoir Sampling): keyᵢ = Uᵢ^(1/wᵢ)
92+
// - Selects k items with largest keys for mathematically correct weighted sampling
93+
// - More efficient than traditional cumulative probability approach
94+
//
95+
// Key characteristics:
96+
// - Mathematically correct weighted random sampling
97+
// - Single pass algorithm with O(n + k log k) complexity
98+
func (p *WeightedRandomPicker) Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult {
99+
log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting maximum '%d' pods from %d candidates using weighted random sampling: %+v",
100+
p.maxNumOfEndpoints, len(scoredPods), scoredPods))
101+
102+
// Check if all weights are zero or negative
103+
allZeroWeights := true
104+
for _, scoredPod := range scoredPods {
105+
if scoredPod.Score > 0 {
106+
allZeroWeights = false
107+
break
108+
}
109+
}
110+
111+
// Delegate to RandomPicker for uniform selection when all weights are zero
112+
if allZeroWeights {
113+
log.FromContext(ctx).V(logutil.DEBUG).Info("All weights are zero, delegating to RandomPicker for uniform selection")
114+
return p.randomPicker.Pick(ctx, cycleState, scoredPods)
115+
}
116+
117+
randomGenerator := rand.New(rand.NewSource(time.Now().UnixNano()))
118+
119+
// A-Res algorithm: keyᵢ = Uᵢ^(1/wᵢ)
120+
weightedPods := make([]weightedScoredPod, 0, len(scoredPods))
121+
122+
for _, scoredPod := range scoredPods {
123+
weight := float64(scoredPod.Score)
124+
125+
// Handle zero or negative weights
126+
if weight <= 0 {
127+
// Assign very small key for zero-weight pods (effectively excludes them)
128+
weightedPods = append(weightedPods, weightedScoredPod{
129+
ScoredPod: scoredPod,
130+
key: 0,
131+
})
132+
continue
133+
}
134+
135+
// Generate random number U in (0,1)
136+
u := randomGenerator.Float64()
137+
if u == 0 {
138+
u = 1e-10 // Avoid log(0)
139+
}
140+
141+
// Calculate key = U^(1/weight)
142+
key := math.Pow(u, 1.0/weight)
143+
144+
weightedPods = append(weightedPods, weightedScoredPod{
145+
ScoredPod: scoredPod,
146+
key: key,
147+
})
148+
}
149+
150+
// Sort by key in descending order (largest keys first)
151+
sort.Slice(weightedPods, func(i, j int) bool {
152+
return weightedPods[i].key > weightedPods[j].key
153+
})
154+
155+
// Select top k pods
156+
selectedCount := min(p.maxNumOfEndpoints, len(weightedPods))
157+
158+
scoredPods = make([]*types.ScoredPod, selectedCount)
159+
for i := range selectedCount {
160+
scoredPods[i] = weightedPods[i].ScoredPod
161+
}
162+
163+
targetPods := make([]types.Pod, len(scoredPods))
164+
for i, scoredPod := range scoredPods {
165+
targetPods[i] = scoredPod
166+
}
167+
168+
return &types.ProfileRunResult{TargetPods: targetPods}
169+
}

0 commit comments

Comments
 (0)