@@ -6,97 +6,93 @@ import (
66
77 "github.com/go-logr/logr/testr"
88 "github.com/google/go-cmp/cmp"
9-
9+ "github.com/stretchr/testify/assert"
1010 k8stypes "k8s.io/apimachinery/pkg/types"
1111 "sigs.k8s.io/controller-runtime/pkg/log"
12- "sigs.k8s.io/gateway-api-inference-extension/cmd/epp/runner"
1312 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
1413 backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" // Import config for thresholds
15- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/common/config/loader"
1614 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling"
15+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
16+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix"
17+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/picker"
1718 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
18- "sigs.k8s.io/gateway-api-inference-extension/test/utils"
1919
20- "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins"
2120 "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/filter"
21+ "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/profile"
22+ "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer"
23+ )
24+
25+ const (
26+ prefill = "prefill"
27+ decode = "decode"
2228)
2329
24- // Tests the default scheduler configuration and expected behavior.
30+ // Tests the scheduler expected behavior.
2531func TestPDSchedule (t * testing.T ) {
2632 pod1 := & types.PodMetrics {
2733 Pod : & backend.Pod {
2834 NamespacedName : k8stypes.NamespacedName {Name : "pod1" },
2935 Address : "1.2.3.4" ,
3036 Labels : map [string ]string {filter .RoleLabel : filter .RolePrefill },
3137 },
32- MetricsState : backendmetrics .NewMetricsState () ,
38+ MetricsState : & backendmetrics.MetricsState { WaitingQueueSize : 0 } ,
3339 }
3440 pod2 := & types.PodMetrics {
3541 Pod : & backend.Pod {
3642 NamespacedName : k8stypes.NamespacedName {Name : "pod2" },
3743 Address : "5.6.7.8" ,
3844 Labels : map [string ]string {filter .RoleLabel : filter .RoleDecode },
3945 },
40- MetricsState : backendmetrics .NewMetricsState () ,
46+ MetricsState : & backendmetrics.MetricsState { WaitingQueueSize : 0 } ,
4147 }
4248 noRolePod1 := & types.PodMetrics {
4349 Pod : & backend.Pod {
4450 NamespacedName : k8stypes.NamespacedName {Name : "noRolePod1" },
4551 Address : "1.1.1.1" ,
4652 },
47- MetricsState : backendmetrics .NewMetricsState (),
48- }
49- noRolePod2 := & types.PodMetrics {
50- Pod : & backend.Pod {
51- NamespacedName : k8stypes.NamespacedName {Name : "noRolePod2" },
52- Address : "2.2.2.2" ,
53- },
54- MetricsState : backendmetrics .NewMetricsState (),
53+ MetricsState : & backendmetrics.MetricsState {WaitingQueueSize : 2 },
5554 }
5655
5756 prefillDecodeResult := & types.SchedulingResult {
5857 ProfileResults : map [string ]* types.ProfileRunResult {
59- " decode" : {
58+ decode : {
6059 TargetPod : & types.ScoredPod {
6160 Pod : pod2 ,
62- Score : 0.0 ,
61+ Score : 0.5 ,
6362 },
6463 },
65- " prefill" : {
64+ prefill : {
6665 TargetPod : & types.ScoredPod {
6766 Pod : pod1 ,
6867 Score : 0.0 ,
6968 },
7069 },
7170 },
72- PrimaryProfileName : " decode" ,
71+ PrimaryProfileName : decode ,
7372 }
7473
7574 decodeResult := & types.SchedulingResult {
7675 ProfileResults : map [string ]* types.ProfileRunResult {
77- " decode" : {
76+ decode : {
7877 TargetPod : & types.ScoredPod {
7978 Pod : pod2 ,
80- Score : 0.0 ,
79+ Score : 0.5 ,
8180 },
8281 },
8382 },
84- PrimaryProfileName : " decode" ,
83+ PrimaryProfileName : decode ,
8584 }
8685
8786 tests := []struct {
88- name string
89- req * types.LLMRequest
90- input []types.Pod
91- wantRes * types.SchedulingResult
92- wantRes2 * types.SchedulingResult
93- wantHeaders map [string ]string
94- unwantedHeaders []string
95- unwantedPodIDs []string
96- err bool
87+ name string
88+ req * types.LLMRequest
89+ input []types.Pod
90+ wantRes * types.SchedulingResult
91+ wantRes2 * types.SchedulingResult // a subsequent call to check prefix cache and how it affects PD
92+ err bool
9793 }{
9894 {
99- name : "no pods in datastore " ,
95+ name : "no candidate pods " ,
10096 req : & types.LLMRequest {
10197 TargetModel : "any-model" ,
10298 Prompt : "12345678901" ,
@@ -111,17 +107,8 @@ func TestPDSchedule(t *testing.T) {
111107 Prompt : "12345678901" ,
112108 },
113109 // pod2 will be picked because it is the only pod with Decode role
114- input : []types.Pod {pod2 },
115- wantRes : & types.SchedulingResult {
116- ProfileResults : map [string ]* types.ProfileRunResult {
117- "decode" : {
118- TargetPod : & types.ScoredPod {
119- Pod : pod2 ,
120- },
121- },
122- },
123- PrimaryProfileName : "decode" ,
124- },
110+ input : []types.Pod {pod2 },
111+ wantRes : decodeResult ,
125112 },
126113 {
127114 name : "one prefill pod, long prompt" ,
@@ -139,7 +126,7 @@ func TestPDSchedule(t *testing.T) {
139126 TargetModel : "critical" ,
140127 Prompt : "12345678906" ,
141128 },
142- // pod2 will be picked because it is the decode pod , pod1 IP will be in the header
129+ // pod2 will be picked in the decode profile result , pod1 will be in the prefill profile result
143130 input : []types.Pod {pod1 , pod2 },
144131 wantRes : prefillDecodeResult ,
145132 wantRes2 : decodeResult ,
@@ -150,138 +137,100 @@ func TestPDSchedule(t *testing.T) {
150137 TargetModel : "critical" ,
151138 Prompt : "12345" ,
152139 },
153- // pod2 will be picked because it is the decode pod, pod1 IP should no be in the header ,
140+ // pod2 will be picked because it is the decode pod, pod1 shouldn't be picked ,
154141 // because the prompt is too short
155142 input : []types.Pod {pod1 , pod2 },
156143 wantRes : decodeResult ,
157144 wantRes2 : decodeResult ,
158145 },
159146 {
160- name : "TestRoles " ,
147+ name : "TestRolesWithNoDecode " ,
161148 req : & types.LLMRequest {
162149 TargetModel : "critical" ,
163150 Prompt : "12345678901" ,
164151 },
165- input : []types.Pod {pod1 , noRolePod1 , noRolePod2 },
166- wantRes : nil , // doesn't matter which pod was selected
167- unwantedPodIDs : []string {pod1 .GetPod ().NamespacedName .String ()},
152+ input : []types.Pod {pod1 , noRolePod1 },
153+ wantRes : & types.SchedulingResult {
154+ ProfileResults : map [string ]* types.ProfileRunResult {
155+ decode : {
156+ TargetPod : & types.ScoredPod {
157+ Pod : noRolePod1 ,
158+ Score : 0.4921875 ,
159+ },
160+ },
161+ prefill : {
162+ TargetPod : & types.ScoredPod {
163+ Pod : pod1 ,
164+ Score : 0.0 ,
165+ },
166+ },
167+ },
168+ PrimaryProfileName : decode ,
169+ },
170+ },
171+ {
172+ name : "1P2D - long prompt" ,
173+ req : & types.LLMRequest {
174+ TargetModel : "critical" ,
175+ Prompt : "12345678906" ,
176+ },
177+ // pod2 will be picked in the decode profile result because it has a higher score than noRolePod1
178+ // pod1 will be in the prefill profile result
179+ input : []types.Pod {pod1 , pod2 , noRolePod1 },
180+ wantRes : prefillDecodeResult ,
181+ wantRes2 : decodeResult ,
168182 },
169183 }
170184
171- runner .RegisterAllPlugins ()
172- plugins .RegisterAllPlugins ()
173-
174185 ctx := context .Background ()
175186 logger := testr .New (t )
176187 ctx = log .IntoContext (ctx , logger )
177188
178189 for _ , test := range tests {
179190 t .Run (test .name , func (t * testing.T ) {
180- handle := utils .NewTestHandle (ctx )
181-
182- eppConfig , err := loader .LoadConfig ([]byte (pdSchedulerConfigYaml ), "" )
183- if err != nil {
184- t .Errorf ("Unexpected error, got %v" , err )
185- }
186-
187- err = loader .LoadPluginReferences (eppConfig .Plugins , handle )
188- if err != nil {
189- t .Errorf ("Unexpected error, got %v" , err )
190- }
191-
192- schedulderConfig , err := loader .LoadSchedulerConfig (eppConfig .SchedulingProfiles , handle )
193- if err != nil {
194- t .Errorf ("Unexpected error, got %v" , err )
195- }
196-
197- scheduler := scheduling .NewSchedulerWithConfig (schedulderConfig )
191+ // initialize scheduler with config
192+ prefixScorer := prefix .New (prefix.Config {HashBlockSize : 5 , MaxPrefixBlocksToMatch : 256 , LRUCapacityPerServer : 31250 })
193+
194+ prefillSchedulerProfile := framework .NewSchedulerProfile ().
195+ WithFilters (filter .NewPrefillFilter ()).
196+ WithPicker (picker .NewMaxScorePicker ())
197+ err := prefillSchedulerProfile .AddPlugins (framework .NewWeightedScorer (prefixScorer , 50 ))
198+ assert .NoError (t , err , "SchedulerProfile AddPlugins returned unexpected error" )
199+
200+ decodeSchedulerProfile := framework .NewSchedulerProfile ().
201+ WithFilters (filter .NewDecodeFilter ()).
202+ WithScorers (framework .NewWeightedScorer (scorer .NewLoadAwareScorer (scorer .QueueThresholdDefault ), 1 )).
203+ WithPicker (picker .NewMaxScorePicker ())
204+ err = decodeSchedulerProfile .AddPlugins (framework .NewWeightedScorer (prefixScorer , 0 ))
205+ assert .NoError (t , err , "SchedulerProfile AddPlugins returned unexpected error" )
206+
207+ profileHandle := profile .NewPdProfileHandler (prefill , decode , 10 , 5 )
208+
209+ schedulerConfig := scheduling .NewSchedulerConfig (profileHandle , map [string ]* framework.SchedulerProfile {
210+ prefill : prefillSchedulerProfile ,
211+ decode : decodeSchedulerProfile ,
212+ })
213+ scheduler := scheduling .NewSchedulerWithConfig (schedulerConfig )
198214 got , err := scheduler .Schedule (ctx , test .req , test .input )
199215
200216 if test .err != (err != nil ) {
201217 t .Errorf ("Unexpected error, got %v, want %v" , err , test .err )
202218 }
203219
204- if test .wantRes != nil {
205- if diff := cmp .Diff (test .wantRes , got ); diff != "" {
206- t .Errorf ("Unexpected output (-want +got): %v" , diff )
207- }
208-
209- for header , value := range test .wantHeaders {
210- gotValue , ok := test .req .Headers [header ]
211- if ! ok {
212- t .Errorf ("Missing header: %s" , header )
213- } else if gotValue != value {
214- t .Errorf ("Wrong header value for %s: want %s got %s)" , header , value , gotValue )
215- }
216- }
217-
218- for _ , header := range test .unwantedHeaders {
219- if _ , exists := test .req .Headers [header ]; exists {
220- t .Errorf ("Unwanted header %s exists" , header )
221- }
222- }
223- }
224-
225- if len (test .unwantedPodIDs ) > 0 {
226- // ensure that target pod is not one of the unwanted
227- profileRes , found := got .ProfileResults [got .PrimaryProfileName ]
228- if found {
229- for _ , podID := range test .unwantedPodIDs {
230- if podID == profileRes .TargetPod .GetPod ().NamespacedName .String () {
231- t .Errorf ("Unwanted pod was selected: %s" , podID )
232- }
233- }
234- }
220+ if diff := cmp .Diff (test .wantRes , got ); diff != "" {
221+ t .Errorf ("Unexpected output (-want +got): %v" , diff )
235222 }
236223
237224 if test .wantRes2 != nil { // Checking the prefix match in the decode pod.
238225 got , err = scheduler .Schedule (ctx , test .req , test .input )
239226 if test .err != (err != nil ) {
240- t .Errorf ("Unexpected error, got %v, want %v" , err , test .err )
227+ t .Errorf ("Unexpected error in schedule call , got %v, want %v" , err , test .err )
241228 }
242229
243230 if diff := cmp .Diff (test .wantRes2 , got ); diff != "" {
244- t .Errorf ("Unexpected output (-want +got): %v" , diff )
231+ t .Errorf ("Unexpected output in subsequent schedule call (-want +got): %v" , diff )
245232 }
246233 }
247-
248234 })
249235 }
250236}
251-
252- const pdSchedulerConfigYaml = `
253- apiVersion: inference.networking.x-k8s.io/v1alpha1
254- kind: EndpointPickerConfig
255- plugins:
256- - type: prefill-header
257- - name: prefixScorer
258- type: prefix-cache
259- parameters:
260- hashBlockSize: 5
261- maxPrefixBlocksToMatch: 256
262- lruCapacityPerServer: 31250
263- - name: prefillFilter
264- type: prefill-filter
265- - name: decodeFilter
266- type: decode-filter
267- - type: max-score
268- - type: pd-profile-handler
269- parameters:
270- hashBlockSize: 5
271- maxPrefixBlocksToMatch: 256
272- lruCapacityPerServer: 31250
273- threshold: 10
274- schedulingProfiles:
275- - name: prefill
276- plugins:
277- - pluginRef: prefillFilter
278- - pluginRef: max-score
279- - pluginRef: prefixScorer
280- weight: 50
281- - name: decode
282- plugins:
283- - pluginRef: decodeFilter
284- - pluginRef: max-score
285- - pluginRef: prefixScorer
286- weight: 0
287- `
0 commit comments