@@ -24,10 +24,12 @@ import (
2424 "strings"
2525 "testing"
2626
27+ "github.com/google/uuid"
2728 "github.com/stretchr/testify/assert"
2829 k8stypes "k8s.io/apimachinery/pkg/types"
2930
3031 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
32+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
3133 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
3234)
3335
@@ -38,20 +40,20 @@ func TestPrefixPlugin(t *testing.T) {
3840 MaxPrefixBlocksToMatch : DefaultMaxPrefixBlocks ,
3941 LRUCapacityPerServer : DefaultLRUCapacityPerServer ,
4042 }
41- plugin := New (config )
43+ plugin := New (context . Background (), config )
4244
4345 pod1 := & types.PodMetrics {Pod : & backend.Pod {NamespacedName : k8stypes.NamespacedName {Name : "pod1" }}}
4446 pod2 := & types.PodMetrics {Pod : & backend.Pod {NamespacedName : k8stypes.NamespacedName {Name : "pod2" }}}
4547 pods := []types.Pod {pod1 , pod2 }
4648
4749 // First request.
4850 req1 := & types.LLMRequest {
51+ RequestId : uuid .NewString (),
4952 TargetModel : "test-model1" ,
5053 Prompt : "aaaaaa" ,
5154 }
52- cycleState1 := types .NewCycleState ()
53- scores := plugin .Score (context .Background (), cycleState1 , req1 , pods )
54- state , err := types .ReadCycleStateKey [* SchedulingContextState ](cycleState1 , PrefixCachePluginType )
55+ scores := plugin .Score (context .Background (), nil , req1 , pods )
56+ state , err := plugins .ReadPluginStateKey [* SchedulingContextState ](plugin .pluginState , req1 .RequestId , PrefixCachePluginType )
5557 assert .NoError (t , err )
5658 t .Logf ("Hashes %+v, cached servers: %+v" , state .PrefixHashes , state .PrefixCacheServers )
5759 // Input size is 6, hash block size is 4, the last 2 characters are ignored.
@@ -62,17 +64,23 @@ func TestPrefixPlugin(t *testing.T) {
6264 assert .Equal (t , float64 (0 ), scores [pod2 ], "score for pod2" )
6365
6466 // Simulate pod1 was picked.
65- plugin .PostCycle (context .Background (), cycleState1 , & types.ProfileRunResult {TargetPods : []types.Pod {pod1 }})
67+ schedulingResult := & types.SchedulingResult {
68+ PrimaryProfileName : "default" ,
69+ ProfileResults : map [string ]* types.ProfileRunResult {
70+ "default" : & types.ProfileRunResult {TargetPods : []types.Pod {pod1 }},
71+ },
72+ }
73+ plugin .PreRequest (context .Background (), req1 , schedulingResult , 0 )
6674
6775 // Second request doesn't share any prefix with first one. It should be added to the cache but
6876 // the pod score should be 0.
6977 req2 := & types.LLMRequest {
78+ RequestId : uuid .NewString (),
7079 TargetModel : "test-model2" ,
7180 Prompt : "bbbbbb" ,
7281 }
73- cycleState2 := types .NewCycleState ()
74- scores = plugin .Score (context .Background (), cycleState2 , req2 , pods )
75- state , err = types .ReadCycleStateKey [* SchedulingContextState ](cycleState2 , PrefixCachePluginType )
82+ scores = plugin .Score (context .Background (), nil , req2 , pods )
83+ state , err = plugins .ReadPluginStateKey [* SchedulingContextState ](plugin .pluginState , req2 .RequestId , PrefixCachePluginType )
7684 assert .NoError (t , err )
7785 t .Logf ("Hashes %+v, cached servers: %+v" , state .PrefixHashes , state .PrefixCacheServers )
7886 // Input size is 6, hash block size is 4, the last 2 characters are ignored.
@@ -83,16 +91,22 @@ func TestPrefixPlugin(t *testing.T) {
8391 assert .Equal (t , float64 (0 ), scores [pod2 ], "score for pod2" )
8492
8593 // Simulate pod2 was picked.
86- plugin .PostCycle (context .Background (), cycleState2 , & types.ProfileRunResult {TargetPods : []types.Pod {pod2 }})
94+ schedulingResult = & types.SchedulingResult {
95+ PrimaryProfileName : "default" ,
96+ ProfileResults : map [string ]* types.ProfileRunResult {
97+ "default" : & types.ProfileRunResult {TargetPods : []types.Pod {pod2 }},
98+ },
99+ }
100+ plugin .PreRequest (context .Background (), req2 , schedulingResult , 0 )
87101
88102 // Third request shares partial prefix with first one.
89103 req3 := & types.LLMRequest {
104+ RequestId : uuid .NewString (),
90105 TargetModel : "test-model1" ,
91106 Prompt : "aaaabbbb" ,
92107 }
93- cycleState3 := types .NewCycleState ()
94- scores = plugin .Score (context .Background (), cycleState3 , req3 , pods )
95- state , err = types .ReadCycleStateKey [* SchedulingContextState ](cycleState3 , PrefixCachePluginType )
108+ scores = plugin .Score (context .Background (), nil , req3 , pods )
109+ state , err = plugins .ReadPluginStateKey [* SchedulingContextState ](plugin .pluginState , req3 .RequestId , PrefixCachePluginType )
96110 assert .NoError (t , err )
97111 t .Logf ("Hashes %+v, cached servers: %+v" , state .PrefixHashes , state .PrefixCacheServers )
98112 // Input size is 8, hash block size is 4, so 2 hashes will be calculated.
@@ -102,16 +116,22 @@ func TestPrefixPlugin(t *testing.T) {
102116 assert .Equal (t , float64 (2 )/ float64 (3 ), scores [pod1 ], "score should be 2/3 - the model and the first prefix block match" )
103117 assert .Equal (t , float64 (0 ), scores [pod2 ], "score for pod2" )
104118
105- plugin .PostCycle (context .Background (), cycleState3 , & types.ProfileRunResult {TargetPods : []types.Pod {pod1 }})
119+ schedulingResult = & types.SchedulingResult {
120+ PrimaryProfileName : "default" ,
121+ ProfileResults : map [string ]* types.ProfileRunResult {
122+ "default" : & types.ProfileRunResult {TargetPods : []types.Pod {pod1 }},
123+ },
124+ }
125+ plugin .PreRequest (context .Background (), req3 , schedulingResult , 0 )
106126
107127 // 4th request is same as req3 except the model is different, still no match.
108128 req4 := & types.LLMRequest {
129+ RequestId : uuid .NewString (),
109130 TargetModel : "test-model-new" ,
110131 Prompt : "aaaabbbb" ,
111132 }
112- cycleState4 := types .NewCycleState ()
113- scores = plugin .Score (context .Background (), cycleState4 , req4 , pods )
114- state , err = types .ReadCycleStateKey [* SchedulingContextState ](cycleState4 , PrefixCachePluginType )
133+ scores = plugin .Score (context .Background (), nil , req4 , pods )
134+ state , err = plugins .ReadPluginStateKey [* SchedulingContextState ](plugin .pluginState , req4 .RequestId , PrefixCachePluginType )
115135 assert .NoError (t , err )
116136 t .Logf ("Hashes %+v, cached servers: %+v" , state .PrefixHashes , state .PrefixCacheServers )
117137 // Input size is 8, hash block size is 4, so 2 hashes will be calculated.
@@ -121,16 +141,22 @@ func TestPrefixPlugin(t *testing.T) {
121141 assert .Equal (t , float64 (0 ), scores [pod1 ], "score for pod1" )
122142 assert .Equal (t , float64 (0 ), scores [pod2 ], "score for pod2" )
123143
124- plugin .PostCycle (context .Background (), cycleState4 , & types.ProfileRunResult {TargetPods : []types.Pod {pod1 }})
144+ schedulingResult = & types.SchedulingResult {
145+ PrimaryProfileName : "default" ,
146+ ProfileResults : map [string ]* types.ProfileRunResult {
147+ "default" : & types.ProfileRunResult {TargetPods : []types.Pod {pod1 }},
148+ },
149+ }
150+ plugin .PreRequest (context .Background (), req4 , schedulingResult , 0 )
125151
126152 // 5th request shares partial prefix with 3rd one.
127153 req5 := & types.LLMRequest {
154+ RequestId : uuid .NewString (),
128155 TargetModel : "test-model1" ,
129156 Prompt : "aaaabbbbcccc" ,
130157 }
131- cycleState5 := types .NewCycleState ()
132- scores = plugin .Score (context .Background (), cycleState5 , req5 , pods )
133- state , err = types .ReadCycleStateKey [* SchedulingContextState ](cycleState5 , PrefixCachePluginType )
158+ scores = plugin .Score (context .Background (), nil , req5 , pods )
159+ state , err = plugins .ReadPluginStateKey [* SchedulingContextState ](plugin .pluginState , req5 .RequestId , PrefixCachePluginType )
134160 assert .NoError (t , err )
135161 t .Logf ("Hashes %+v, cached servers: %+v" , state .PrefixHashes , state .PrefixCacheServers )
136162 // Input size is 12, hash block size is 4, so 3 hashes will be calculated.
@@ -140,7 +166,13 @@ func TestPrefixPlugin(t *testing.T) {
140166 assert .Equal (t , 0.75 , scores [pod1 ], "score should be 0.75 - the model and the first 2 prefix blocks match" )
141167 assert .Equal (t , float64 (0 ), scores [pod2 ], "score for pod2" )
142168
143- plugin .PostCycle (context .Background (), cycleState5 , & types.ProfileRunResult {TargetPods : []types.Pod {pod1 }})
169+ schedulingResult = & types.SchedulingResult {
170+ PrimaryProfileName : "default" ,
171+ ProfileResults : map [string ]* types.ProfileRunResult {
172+ "default" : & types.ProfileRunResult {TargetPods : []types.Pod {pod1 }},
173+ },
174+ }
175+ plugin .PreRequest (context .Background (), req5 , schedulingResult , 0 )
144176}
145177
146178// TestPrefixPluginStress is a stress test for the prefix scoring plugin, using prompts of increasing length.
@@ -153,7 +185,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
153185 LRUCapacityPerServer : DefaultLRUCapacityPerServer ,
154186 }
155187
156- plugin := New (config )
188+ plugin := New (context . Background (), config )
157189 types .NewCycleState ()
158190 var promptLen []int
159191 for i := 1 ; i <= 1024 ; i ++ {
@@ -174,17 +206,23 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
174206
175207 pods := []types.Pod {pod }
176208 req := & types.LLMRequest {
209+ RequestId : uuid .NewString (),
177210 TargetModel : "model-stress" ,
178211 Prompt : prompt ,
179212 }
180213
181214 // First cycle: simulate scheduling and insert prefix info into the cache
182- cycleState := types .NewCycleState ()
183- plugin .Score (context .Background (), cycleState , req , pods )
184- plugin .PostCycle (context .Background (), cycleState , & types.ProfileRunResult {TargetPods : []types.Pod {pod }})
215+ plugin .Score (context .Background (), nil , req , pods )
216+ schedulingResult := & types.SchedulingResult {
217+ PrimaryProfileName : "default" ,
218+ ProfileResults : map [string ]* types.ProfileRunResult {
219+ "default" : & types.ProfileRunResult {TargetPods : []types.Pod {pod }},
220+ },
221+ }
222+ plugin .PreRequest (context .Background (), req , schedulingResult , 0 )
185223
186224 // Second cycle: validate internal state
187- state , err := types . ReadCycleStateKey [* SchedulingContextState ](cycleState , PrefixCachePluginType )
225+ state , err := plugins . ReadPluginStateKey [* SchedulingContextState ](plugin . pluginState , req . RequestId , PrefixCachePluginType )
188226 assert .NoError (b , err )
189227 expectedHashes := int (math .Min (float64 (maxPrefixBlocks + 1 ), float64 (len (req .Prompt )/ blockSize + 1 ))) // the extra one is for the model.
190228 assert .Equal (b , expectedHashes , len (state .PrefixHashes ), "number of hashes is incorrect" )
0 commit comments