Skip to content

Commit 36406ef

Browse files
authored
updated scheduler unit test (#229)
1 parent f83d798 commit 36406ef

File tree

1 file changed

+91
-142
lines changed

1 file changed

+91
-142
lines changed

pkg/scheduling/pd/scheduler_test.go

Lines changed: 91 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -6,97 +6,93 @@ import (
66

77
"github.com/go-logr/logr/testr"
88
"github.com/google/go-cmp/cmp"
9-
9+
"github.com/stretchr/testify/assert"
1010
k8stypes "k8s.io/apimachinery/pkg/types"
1111
"sigs.k8s.io/controller-runtime/pkg/log"
12-
"sigs.k8s.io/gateway-api-inference-extension/cmd/epp/runner"
1312
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
1413
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" // Import config for thresholds
15-
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/common/config/loader"
1614
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling"
15+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
16+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix"
17+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/picker"
1718
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
18-
"sigs.k8s.io/gateway-api-inference-extension/test/utils"
1919

20-
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins"
2120
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/filter"
21+
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/profile"
22+
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer"
23+
)
24+
25+
const (
26+
prefill = "prefill"
27+
decode = "decode"
2228
)
2329

24-
// Tests the default scheduler configuration and expected behavior.
30+
// Tests the scheduler expected behavior.
2531
func TestPDSchedule(t *testing.T) {
2632
pod1 := &types.PodMetrics{
2733
Pod: &backend.Pod{
2834
NamespacedName: k8stypes.NamespacedName{Name: "pod1"},
2935
Address: "1.2.3.4",
3036
Labels: map[string]string{filter.RoleLabel: filter.RolePrefill},
3137
},
32-
MetricsState: backendmetrics.NewMetricsState(),
38+
MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 0},
3339
}
3440
pod2 := &types.PodMetrics{
3541
Pod: &backend.Pod{
3642
NamespacedName: k8stypes.NamespacedName{Name: "pod2"},
3743
Address: "5.6.7.8",
3844
Labels: map[string]string{filter.RoleLabel: filter.RoleDecode},
3945
},
40-
MetricsState: backendmetrics.NewMetricsState(),
46+
MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 0},
4147
}
4248
noRolePod1 := &types.PodMetrics{
4349
Pod: &backend.Pod{
4450
NamespacedName: k8stypes.NamespacedName{Name: "noRolePod1"},
4551
Address: "1.1.1.1",
4652
},
47-
MetricsState: backendmetrics.NewMetricsState(),
48-
}
49-
noRolePod2 := &types.PodMetrics{
50-
Pod: &backend.Pod{
51-
NamespacedName: k8stypes.NamespacedName{Name: "noRolePod2"},
52-
Address: "2.2.2.2",
53-
},
54-
MetricsState: backendmetrics.NewMetricsState(),
53+
MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 2},
5554
}
5655

5756
prefillDecodeResult := &types.SchedulingResult{
5857
ProfileResults: map[string]*types.ProfileRunResult{
59-
"decode": {
58+
decode: {
6059
TargetPod: &types.ScoredPod{
6160
Pod: pod2,
62-
Score: 0.0,
61+
Score: 0.5,
6362
},
6463
},
65-
"prefill": {
64+
prefill: {
6665
TargetPod: &types.ScoredPod{
6766
Pod: pod1,
6867
Score: 0.0,
6968
},
7069
},
7170
},
72-
PrimaryProfileName: "decode",
71+
PrimaryProfileName: decode,
7372
}
7473

7574
decodeResult := &types.SchedulingResult{
7675
ProfileResults: map[string]*types.ProfileRunResult{
77-
"decode": {
76+
decode: {
7877
TargetPod: &types.ScoredPod{
7978
Pod: pod2,
80-
Score: 0.0,
79+
Score: 0.5,
8180
},
8281
},
8382
},
84-
PrimaryProfileName: "decode",
83+
PrimaryProfileName: decode,
8584
}
8685

8786
tests := []struct {
88-
name string
89-
req *types.LLMRequest
90-
input []types.Pod
91-
wantRes *types.SchedulingResult
92-
wantRes2 *types.SchedulingResult
93-
wantHeaders map[string]string
94-
unwantedHeaders []string
95-
unwantedPodIDs []string
96-
err bool
87+
name string
88+
req *types.LLMRequest
89+
input []types.Pod
90+
wantRes *types.SchedulingResult
91+
wantRes2 *types.SchedulingResult // a subsequent call to check prefix cache and how it affects PD
92+
err bool
9793
}{
9894
{
99-
name: "no pods in datastore",
95+
name: "no candidate pods",
10096
req: &types.LLMRequest{
10197
TargetModel: "any-model",
10298
Prompt: "12345678901",
@@ -111,17 +107,8 @@ func TestPDSchedule(t *testing.T) {
111107
Prompt: "12345678901",
112108
},
113109
// pod2 will be picked because it is the only pod with Decode role
114-
input: []types.Pod{pod2},
115-
wantRes: &types.SchedulingResult{
116-
ProfileResults: map[string]*types.ProfileRunResult{
117-
"decode": {
118-
TargetPod: &types.ScoredPod{
119-
Pod: pod2,
120-
},
121-
},
122-
},
123-
PrimaryProfileName: "decode",
124-
},
110+
input: []types.Pod{pod2},
111+
wantRes: decodeResult,
125112
},
126113
{
127114
name: "one prefill pod, long prompt",
@@ -139,7 +126,7 @@ func TestPDSchedule(t *testing.T) {
139126
TargetModel: "critical",
140127
Prompt: "12345678906",
141128
},
142-
// pod2 will be picked because it is the decode pod, pod1 IP will be in the header
129+
// pod2 will be picked in the decode profile result, pod1 will be in the prefill profile result
143130
input: []types.Pod{pod1, pod2},
144131
wantRes: prefillDecodeResult,
145132
wantRes2: decodeResult,
@@ -150,138 +137,100 @@ func TestPDSchedule(t *testing.T) {
150137
TargetModel: "critical",
151138
Prompt: "12345",
152139
},
153-
// pod2 will be picked because it is the decode pod, pod1 IP should no be in the header,
140+
// pod2 will be picked because it is the decode pod, pod1 shouldn't be picked,
154141
// because the prompt is too short
155142
input: []types.Pod{pod1, pod2},
156143
wantRes: decodeResult,
157144
wantRes2: decodeResult,
158145
},
159146
{
160-
name: "TestRoles",
147+
name: "TestRolesWithNoDecode",
161148
req: &types.LLMRequest{
162149
TargetModel: "critical",
163150
Prompt: "12345678901",
164151
},
165-
input: []types.Pod{pod1, noRolePod1, noRolePod2},
166-
wantRes: nil, // doesn't matter which pod was selected
167-
unwantedPodIDs: []string{pod1.GetPod().NamespacedName.String()},
152+
input: []types.Pod{pod1, noRolePod1},
153+
wantRes: &types.SchedulingResult{
154+
ProfileResults: map[string]*types.ProfileRunResult{
155+
decode: {
156+
TargetPod: &types.ScoredPod{
157+
Pod: noRolePod1,
158+
Score: 0.4921875,
159+
},
160+
},
161+
prefill: {
162+
TargetPod: &types.ScoredPod{
163+
Pod: pod1,
164+
Score: 0.0,
165+
},
166+
},
167+
},
168+
PrimaryProfileName: decode,
169+
},
170+
},
171+
{
172+
name: "1P2D - long prompt",
173+
req: &types.LLMRequest{
174+
TargetModel: "critical",
175+
Prompt: "12345678906",
176+
},
177+
// pod2 will be picked in the decode profile result cause it has higher score than noRolePod1
178+
// pod1 will be in the prefill profile result
179+
input: []types.Pod{pod1, pod2, noRolePod1},
180+
wantRes: prefillDecodeResult,
181+
wantRes2: decodeResult,
168182
},
169183
}
170184

171-
runner.RegisterAllPlugins()
172-
plugins.RegisterAllPlugins()
173-
174185
ctx := context.Background()
175186
logger := testr.New(t)
176187
ctx = log.IntoContext(ctx, logger)
177188

178189
for _, test := range tests {
179190
t.Run(test.name, func(t *testing.T) {
180-
handle := utils.NewTestHandle(ctx)
181-
182-
eppConfig, err := loader.LoadConfig([]byte(pdSchedulerConfigYaml), "")
183-
if err != nil {
184-
t.Errorf("Unexpected error, got %v", err)
185-
}
186-
187-
err = loader.LoadPluginReferences(eppConfig.Plugins, handle)
188-
if err != nil {
189-
t.Errorf("Unexpected error, got %v", err)
190-
}
191-
192-
schedulderConfig, err := loader.LoadSchedulerConfig(eppConfig.SchedulingProfiles, handle)
193-
if err != nil {
194-
t.Errorf("Unexpected error, got %v", err)
195-
}
196-
197-
scheduler := scheduling.NewSchedulerWithConfig(schedulderConfig)
191+
// initialize scheduler with config
192+
prefixScorer := prefix.New(prefix.Config{HashBlockSize: 5, MaxPrefixBlocksToMatch: 256, LRUCapacityPerServer: 31250})
193+
194+
prefillSchedulerProfile := framework.NewSchedulerProfile().
195+
WithFilters(filter.NewPrefillFilter()).
196+
WithPicker(picker.NewMaxScorePicker())
197+
err := prefillSchedulerProfile.AddPlugins(framework.NewWeightedScorer(prefixScorer, 50))
198+
assert.NoError(t, err, "SchedulerProfile AddPlugins returned unexpected error")
199+
200+
decodeSchedulerProfile := framework.NewSchedulerProfile().
201+
WithFilters(filter.NewDecodeFilter()).
202+
WithScorers(framework.NewWeightedScorer(scorer.NewLoadAwareScorer(scorer.QueueThresholdDefault), 1)).
203+
WithPicker(picker.NewMaxScorePicker())
204+
err = decodeSchedulerProfile.AddPlugins(framework.NewWeightedScorer(prefixScorer, 0))
205+
assert.NoError(t, err, "SchedulerProfile AddPlugins returned unexpected error")
206+
207+
profileHandle := profile.NewPdProfileHandler(prefill, decode, 10, 5)
208+
209+
schedulerConfig := scheduling.NewSchedulerConfig(profileHandle, map[string]*framework.SchedulerProfile{
210+
prefill: prefillSchedulerProfile,
211+
decode: decodeSchedulerProfile,
212+
})
213+
scheduler := scheduling.NewSchedulerWithConfig(schedulerConfig)
198214
got, err := scheduler.Schedule(ctx, test.req, test.input)
199215

200216
if test.err != (err != nil) {
201217
t.Errorf("Unexpected error, got %v, want %v", err, test.err)
202218
}
203219

204-
if test.wantRes != nil {
205-
if diff := cmp.Diff(test.wantRes, got); diff != "" {
206-
t.Errorf("Unexpected output (-want +got): %v", diff)
207-
}
208-
209-
for header, value := range test.wantHeaders {
210-
gotValue, ok := test.req.Headers[header]
211-
if !ok {
212-
t.Errorf("Missing header: %s", header)
213-
} else if gotValue != value {
214-
t.Errorf("Wrong header value for %s: want %s got %s)", header, value, gotValue)
215-
}
216-
}
217-
218-
for _, header := range test.unwantedHeaders {
219-
if _, exists := test.req.Headers[header]; exists {
220-
t.Errorf("Unwanted header %s exists", header)
221-
}
222-
}
223-
}
224-
225-
if len(test.unwantedPodIDs) > 0 {
226-
// ensure that target pod is not one of the unwanted
227-
profileRes, found := got.ProfileResults[got.PrimaryProfileName]
228-
if found {
229-
for _, podID := range test.unwantedPodIDs {
230-
if podID == profileRes.TargetPod.GetPod().NamespacedName.String() {
231-
t.Errorf("Unwanted pod was selected: %s", podID)
232-
}
233-
}
234-
}
220+
if diff := cmp.Diff(test.wantRes, got); diff != "" {
221+
t.Errorf("Unexpected output (-want +got): %v", diff)
235222
}
236223

237224
if test.wantRes2 != nil { // Checking the prefix match in the decode pod.
238225
got, err = scheduler.Schedule(ctx, test.req, test.input)
239226
if test.err != (err != nil) {
240-
t.Errorf("Unexpected error, got %v, want %v", err, test.err)
227+
t.Errorf("Unexpected error in schedule call, got %v, want %v", err, test.err)
241228
}
242229

243230
if diff := cmp.Diff(test.wantRes2, got); diff != "" {
244-
t.Errorf("Unexpected output (-want +got): %v", diff)
231+
t.Errorf("Unexpected output in subsequent schedule call (-want +got): %v", diff)
245232
}
246233
}
247-
248234
})
249235
}
250236
}
251-
252-
const pdSchedulerConfigYaml = `
253-
apiVersion: inference.networking.x-k8s.io/v1alpha1
254-
kind: EndpointPickerConfig
255-
plugins:
256-
- type: prefill-header
257-
- name: prefixScorer
258-
type: prefix-cache
259-
parameters:
260-
hashBlockSize: 5
261-
maxPrefixBlocksToMatch: 256
262-
lruCapacityPerServer: 31250
263-
- name: prefillFilter
264-
type: prefill-filter
265-
- name: decodeFilter
266-
type: decode-filter
267-
- type: max-score
268-
- type: pd-profile-handler
269-
parameters:
270-
hashBlockSize: 5
271-
maxPrefixBlocksToMatch: 256
272-
lruCapacityPerServer: 31250
273-
threshold: 10
274-
schedulingProfiles:
275-
- name: prefill
276-
plugins:
277-
- pluginRef: prefillFilter
278-
- pluginRef: max-score
279-
- pluginRef: prefixScorer
280-
weight: 50
281-
- name: decode
282-
plugins:
283-
- pluginRef: decodeFilter
284-
- pluginRef: max-score
285-
- pluginRef: prefixScorer
286-
weight: 0
287-
`

0 commit comments

Comments
 (0)