From a8d9f4a4cd6433adfac256efcaf195c7c07e0521 Mon Sep 17 00:00:00 2001 From: Nir Rozenbaum Date: Thu, 4 Sep 2025 08:35:04 +0300 Subject: [PATCH] minor updates and godoc to weighted random picker Signed-off-by: Nir Rozenbaum --- .../plugins/picker/weighted_random_picker.go | 96 ++++++++----------- 1 file changed, 40 insertions(+), 56 deletions(-) diff --git a/pkg/epp/scheduling/framework/plugins/picker/weighted_random_picker.go b/pkg/epp/scheduling/framework/plugins/picker/weighted_random_picker.go index c12ab72b0..540ede43c 100644 --- a/pkg/epp/scheduling/framework/plugins/picker/weighted_random_picker.go +++ b/pkg/epp/scheduling/framework/plugins/picker/weighted_random_picker.go @@ -22,6 +22,7 @@ import ( "fmt" "math" "math/rand" + "slices" "sort" "time" @@ -43,12 +44,12 @@ type weightedScoredPod struct { key float64 } +// compile-time type validation var _ framework.Picker = &WeightedRandomPicker{} +// WeightedRandomPickerFactory defines the factory function for WeightedRandomPicker. func WeightedRandomPickerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { - parameters := pickerParameters{ - MaxNumOfEndpoints: DefaultMaxNumOfEndpoints, - } + parameters := pickerParameters{MaxNumOfEndpoints: DefaultMaxNumOfEndpoints} if rawParameters != nil { if err := json.Unmarshal(rawParameters, ¶meters); err != nil { return nil, fmt.Errorf("failed to parse the parameters of the '%s' picker - %w", WeightedRandomPickerType, err) @@ -58,9 +59,10 @@ func WeightedRandomPickerFactory(name string, rawParameters json.RawMessage, _ p return NewWeightedRandomPicker(parameters.MaxNumOfEndpoints).WithName(name), nil } +// NewWeightedRandomPicker initializes a new WeightedRandomPicker and returns its pointer. func NewWeightedRandomPicker(maxNumOfEndpoints int) *WeightedRandomPicker { if maxNumOfEndpoints <= 0 { - maxNumOfEndpoints = DefaultMaxNumOfEndpoints + maxNumOfEndpoints = DefaultMaxNumOfEndpoints // on invalid configuration value, fallback to default value } return &WeightedRandomPicker{ @@ -70,81 +72,68 @@ func NewWeightedRandomPicker(maxNumOfEndpoints int) *WeightedRandomPicker { } } +// WeightedRandomPicker picks pod(s) from the list of candidates based on weighted random sampling using A-Res algorithm. +// Reference: https://utopia.duth.gr/~pefraimi/research/data/2007EncOfAlg.pdf. +// +// The picker at its core is picking pods randomly, where the probability of the pod to get picked is derived +// from its weighted score. +// Algorithm: +// - Uses A-Res (Algorithm for Reservoir Sampling): keyᵢ = Uᵢ^(1/wᵢ) +// - Selects k items with largest keys for mathematically correct weighted sampling +// - More efficient than traditional cumulative probability approach +// +// Key characteristics: +// - Mathematically correct weighted random sampling +// - Single pass algorithm with O(n + k log k) complexity type WeightedRandomPicker struct { typedName plugins.TypedName maxNumOfEndpoints int randomPicker *RandomPicker // fallback for zero weights } +// WithName sets the name of the picker. func (p *WeightedRandomPicker) WithName(name string) *WeightedRandomPicker { p.typedName.Name = name return p } +// TypedName returns the type and name tuple of this plugin instance. func (p *WeightedRandomPicker) TypedName() plugins.TypedName { return p.typedName } -// WeightedRandomPicker performs weighted random sampling using A-Res algorithm. -// Reference: https://utopia.duth.gr/~pefraimi/research/data/2007EncOfAlg.pdf -// Algorithm: -// - Uses A-Res (Algorithm for Reservoir Sampling): keyᵢ = Uᵢ^(1/wᵢ) -// - Selects k items with largest keys for mathematically correct weighted sampling -// - More efficient than traditional cumulative probability approach -// -// Key characteristics: -// - Mathematically correct weighted random sampling -// - Single pass algorithm with O(n + k log k) complexity +// Pick selects the pod(s) randomly from the list of candidates, where the probability of the pod to get picked is derived +// from its weighted score. func (p *WeightedRandomPicker) Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.ProfileRunResult { - log.FromContext(ctx).V(logutil.DEBUG).Info(fmt.Sprintf("Selecting maximum '%d' pods from %d candidates using weighted random sampling: %+v", - p.maxNumOfEndpoints, len(scoredPods), scoredPods)) - - // Check if all weights are zero or negative - allZeroWeights := true - for _, scoredPod := range scoredPods { - if scoredPod.Score > 0 { - allZeroWeights = false - break - } - } - - // Delegate to RandomPicker for uniform selection when all weights are zero - if allZeroWeights { - log.FromContext(ctx).V(logutil.DEBUG).Info("All weights are zero, delegating to RandomPicker for uniform selection") + // Check if there is at least one pod with Score > 0, if not let random picker run + if slices.IndexFunc(scoredPods, func(scoredPod *types.ScoredPod) bool { return scoredPod.Score > 0 }) == -1 { + log.FromContext(ctx).V(logutil.DEBUG).Info("All scores are zero, delegating to RandomPicker for uniform selection") return p.randomPicker.Pick(ctx, cycleState, scoredPods) } + log.FromContext(ctx).V(logutil.DEBUG).Info("Selecting pods from candidates by random weighted picker", "max-num-of-endpoints", p.maxNumOfEndpoints, + "num-of-candidates", len(scoredPods), "scored-pods", scoredPods) + randomGenerator := rand.New(rand.NewSource(time.Now().UnixNano())) // A-Res algorithm: keyᵢ = Uᵢ^(1/wᵢ) - weightedPods := make([]weightedScoredPod, 0, len(scoredPods)) - - for _, scoredPod := range scoredPods { - weight := float64(scoredPod.Score) - - // Handle zero or negative weights - if weight <= 0 { - // Assign very small key for zero-weight pods (effectively excludes them) - weightedPods = append(weightedPods, weightedScoredPod{ - ScoredPod: scoredPod, - key: 0, - }) + weightedPods := make([]weightedScoredPod, len(scoredPods)) + + for i, scoredPod := range scoredPods { + // Handle zero score + if scoredPod.Score <= 0 { + // Assign key=0 for zero-score pods (effectively excludes them from selection) + weightedPods[i] = weightedScoredPod{ScoredPod: scoredPod, key: 0} continue } - // Generate random number U in (0,1) + // If we're here the scoredPod.Score > 0. Generate a random number U in (0,1) u := randomGenerator.Float64() if u == 0 { u = 1e-10 // Avoid log(0) } - // Calculate key = U^(1/weight) - key := math.Pow(u, 1.0/weight) - - weightedPods = append(weightedPods, weightedScoredPod{ - ScoredPod: scoredPod, - key: key, - }) + weightedPods[i] = weightedScoredPod{ScoredPod: scoredPod, key: math.Pow(u, 1.0/scoredPod.Score)} // key = U^(1/weight) } // Sort by key in descending order (largest keys first) @@ -155,14 +144,9 @@ func (p *WeightedRandomPicker) Pick(ctx context.Context, cycleState *types.Cycle // Select top k pods selectedCount := min(p.maxNumOfEndpoints, len(weightedPods)) - scoredPods = make([]*types.ScoredPod, selectedCount) + targetPods := make([]types.Pod, selectedCount) for i := range selectedCount { - scoredPods[i] = weightedPods[i].ScoredPod - } - - targetPods := make([]types.Pod, len(scoredPods)) - for i, scoredPod := range scoredPods { - targetPods[i] = scoredPod + targetPods[i] = weightedPods[i].ScoredPod } return &types.ProfileRunResult{TargetPods: targetPods}