Skip to content

Commit 4f39530

Browse files
authored
fix: add validation for load-aware scorer to handle invalid queue thresholds (#240)
Signed-off-by: Kfir Toledo <[email protected]>
1 parent 64465c0 commit 4f39530

File tree

3 files changed

+12
-5
lines changed

3 files changed

+12
-5
lines changed

pkg/plugins/scorer/load_aware_scorer.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@ import (
55
"encoding/json"
66
"fmt"
77

8+
"sigs.k8s.io/controller-runtime/pkg/log"
89
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
910
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
1011
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
12+
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
1113
)
1214

1315
const (
@@ -26,19 +28,24 @@ type loadAwareScorerParameters struct {
2628
var _ framework.Scorer = &LoadAwareScorer{}
2729

2830
// LoadAwareScorerFactory defines the factory function for the LoadAwareScorer
29-
func LoadAwareScorerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
31+
func LoadAwareScorerFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
3032
parameters := loadAwareScorerParameters{Threshold: QueueThresholdDefault}
3133
if rawParameters != nil {
3234
if err := json.Unmarshal(rawParameters, &parameters); err != nil {
3335
return nil, fmt.Errorf("failed to parse the parameters of the '%s' scorer - %w", LoadAwareScorerType, err)
3436
}
3537
}
3638

37-
return NewLoadAwareScorer(parameters.Threshold).WithName(name), nil
39+
return NewLoadAwareScorer(handle.Context(), parameters.Threshold).WithName(name), nil
3840
}
3941

4042
// NewLoadAwareScorer creates a new load based scorer
41-
func NewLoadAwareScorer(queueThreshold int) *LoadAwareScorer {
43+
func NewLoadAwareScorer(ctx context.Context, queueThreshold int) *LoadAwareScorer {
44+
if queueThreshold <= 0 {
45+
log.FromContext(ctx).V(logutil.DEFAULT).Info(fmt.Sprintf("queueThreshold %d should be positive, using default queue threshold %d", queueThreshold, QueueThresholdDefault))
46+
queueThreshold = QueueThresholdDefault // reset only after logging, so the message above reports the rejected value rather than the default twice
47+
}
48+
4249
return &LoadAwareScorer{
4350
typedName: plugins.TypedName{Type: LoadAwareScorerType},
4451
queueThreshold: float64(queueThreshold),

pkg/plugins/scorer/load_aware_scorer_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ func TestLoadBasedScorer(t *testing.T) {
2727
}{
2828
{
2929
name: "load based scorer",
30-
scorer: scorer.NewLoadAwareScorer(0),
30+
scorer: scorer.NewLoadAwareScorer(context.Background(), 10),
3131
req: &types.LLMRequest{
3232
TargetModel: "critical",
3333
},

pkg/scheduling/pd/scheduler_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ func TestPDSchedule(t *testing.T) {
195195

196196
decodeSchedulerProfile := framework.NewSchedulerProfile().
197197
WithFilters(filter.NewDecodeFilter()).
198-
WithScorers(framework.NewWeightedScorer(scorer.NewLoadAwareScorer(scorer.QueueThresholdDefault), 1)).
198+
WithScorers(framework.NewWeightedScorer(scorer.NewLoadAwareScorer(ctx, scorer.QueueThresholdDefault), 1)).
199199
WithPicker(picker.NewMaxScorePicker())
200200
err = decodeSchedulerProfile.AddPlugins(framework.NewWeightedScorer(prefixScorer, 0))
201201
assert.NoError(t, err, "SchedulerProfile AddPlugins returned unexpected error")

0 commit comments

Comments
 (0)