Skip to content

Commit 2f8d627

Browse files
committed
- rename LLMRequestData to LLMRequestBody
- rename LLMRequest.Data to LLMRequest.Body - test refactoring after rebase Signed-off-by: Maroon Ayoub <[email protected]>
1 parent aae9c0f commit 2f8d627

File tree

6 files changed

+47
-47
lines changed

6 files changed

+47
-47
lines changed

pkg/epp/requestcontrol/director.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
103103
}
104104
reqCtx.Request.Body["model"] = reqCtx.TargetModelName
105105

106-
requestData, err := requtil.ExtractRequestData(reqCtx.Request.Body)
106+
requestBody, err := requtil.ExtractRequestBody(reqCtx.Request.Body)
107107
if err != nil {
108108
return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: fmt.Errorf("failed to extract request data: %w", err).Error()}
109109
}
@@ -125,7 +125,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
125125
reqCtx.SchedulingRequest = &schedulingtypes.LLMRequest{
126126
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
127127
TargetModel: reqCtx.TargetModelName,
128-
Data: requestData,
128+
Body: requestBody,
129129
Headers: reqCtx.Request.Headers,
130130
}
131131

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
261261
// For block i, hash(i) = hash(block i content, hash(i-1)).
262262
func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) []BlockHash {
263263
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
264-
if request == nil || request.Data == nil {
264+
if request == nil || request.Body == nil {
265265
loggerDebug.Info("Request or request data is nil, skipping hashing")
266266
return nil
267267
}
@@ -305,10 +305,10 @@ func toBytes(i BlockHash) []byte {
305305
}
306306

307307
func getUserInputBytes(request *types.LLMRequest) ([]byte, error) {
308-
if request.Data.Completions != nil { // assumed to be valid if not nil
309-
return []byte(request.Data.Completions.Prompt), nil
308+
if request.Body.Completions != nil { // assumed to be valid if not nil
309+
return []byte(request.Body.Completions.Prompt), nil
310310
}
311311

312312
// must be chat-completions request at this point, return bytes of entire messages
313-
return json.Marshal(request.Data.ChatCompletions.Messages)
313+
return json.Marshal(request.Body.ChatCompletions.Messages)
314314
}

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
4949
req1 := &types.LLMRequest{
5050
RequestId: uuid.NewString(),
5151
TargetModel: "test-model1",
52-
Data: &types.LLMRequestData{
52+
Body: &types.LLMRequestBody{
5353
Completions: &types.CompletionsRequest{
5454
Prompt: "aaaaaa",
5555
},
@@ -81,7 +81,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
8181
req2 := &types.LLMRequest{
8282
RequestId: uuid.NewString(),
8383
TargetModel: "test-model2",
84-
Data: &types.LLMRequestData{
84+
Body: &types.LLMRequestBody{
8585
Completions: &types.CompletionsRequest{
8686
Prompt: "bbbbbb",
8787
},
@@ -112,7 +112,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
112112
req3 := &types.LLMRequest{
113113
RequestId: uuid.NewString(),
114114
TargetModel: "test-model1",
115-
Data: &types.LLMRequestData{
115+
Body: &types.LLMRequestBody{
116116
Completions: &types.CompletionsRequest{
117117
Prompt: "aaaabbbb",
118118
},
@@ -142,7 +142,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
142142
req4 := &types.LLMRequest{
143143
RequestId: uuid.NewString(),
144144
TargetModel: "test-model-new",
145-
Data: &types.LLMRequestData{
145+
Body: &types.LLMRequestBody{
146146
Completions: &types.CompletionsRequest{
147147
Prompt: "aaaabbbb",
148148
},
@@ -172,7 +172,7 @@ func TestPrefixPluginCompletion(t *testing.T) {
172172
req5 := &types.LLMRequest{
173173
RequestId: uuid.NewString(),
174174
TargetModel: "test-model1",
175-
Data: &types.LLMRequestData{
175+
Body: &types.LLMRequestBody{
176176
Completions: &types.CompletionsRequest{
177177
Prompt: "aaaabbbbcccc",
178178
},
@@ -214,7 +214,7 @@ func TestPrefixPluginChatCompletions(t *testing.T) {
214214
req1 := &types.LLMRequest{
215215
RequestId: uuid.NewString(),
216216
TargetModel: "test-model1",
217-
Data: &types.LLMRequestData{
217+
Body: &types.LLMRequestBody{
218218
ChatCompletions: &types.ChatCompletionsRequest{
219219
Messages: []types.Message{
220220
{Role: "user", Content: "hello world"},
@@ -223,8 +223,8 @@ func TestPrefixPluginChatCompletions(t *testing.T) {
223223
},
224224
},
225225
}
226-
scores := plugin.Score(context.Background(), nil, req1, pods)
227-
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, PrefixCachePluginType)
226+
scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods)
227+
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String()))
228228
assert.NoError(t, err)
229229
t.Logf("Chat completions - Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
230230
// Should have some hashes for the JSON-encoded messages
@@ -249,7 +249,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
249249
req1 := &types.LLMRequest{
250250
RequestId: uuid.NewString(),
251251
TargetModel: "test-model1",
252-
Data: &types.LLMRequestData{
252+
Body: &types.LLMRequestBody{
253253
ChatCompletions: &types.ChatCompletionsRequest{
254254
Messages: []types.Message{
255255
{Role: "system", Content: "You are a helpful assistant"},
@@ -258,8 +258,8 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
258258
},
259259
},
260260
}
261-
scores := plugin.Score(context.Background(), nil, req1, pods)
262-
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, PrefixCachePluginType)
261+
scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods)
262+
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String()))
263263
assert.NoError(t, err)
264264
t.Logf("Initial conversation - Hashes %+v, cached servers: %+v", len(state.PrefixHashes), state.PrefixCacheServers)
265265
initialHashCount := len(state.PrefixHashes)
@@ -281,7 +281,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
281281
req2 := &types.LLMRequest{
282282
RequestId: uuid.NewString(),
283283
TargetModel: "test-model1",
284-
Data: &types.LLMRequestData{
284+
Body: &types.LLMRequestBody{
285285
ChatCompletions: &types.ChatCompletionsRequest{
286286
Messages: []types.Message{
287287
{Role: "system", Content: "You are a helpful assistant"},
@@ -292,8 +292,8 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
292292
},
293293
},
294294
}
295-
scores = plugin.Score(context.Background(), nil, req2, pods)
296-
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, PrefixCachePluginType)
295+
scores = plugin.Score(context.Background(), types.NewCycleState(), req2, pods)
296+
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, plugins.StateKey(plugin.TypedName().String()))
297297
assert.NoError(t, err)
298298
t.Logf("Extended conversation - Hashes %+v, cached servers: %+v", len(state.PrefixHashes), state.PrefixCacheServers)
299299
extendedHashCount := len(state.PrefixHashes)
@@ -313,7 +313,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
313313
req3 := &types.LLMRequest{
314314
RequestId: uuid.NewString(),
315315
TargetModel: "test-model1",
316-
Data: &types.LLMRequestData{
316+
Body: &types.LLMRequestBody{
317317
ChatCompletions: &types.ChatCompletionsRequest{
318318
Messages: []types.Message{
319319
{Role: "system", Content: "You are a helpful assistant"},
@@ -326,8 +326,8 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
326326
},
327327
},
328328
}
329-
scores = plugin.Score(context.Background(), nil, req3, pods)
330-
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req3.RequestId, PrefixCachePluginType)
329+
scores = plugin.Score(context.Background(), types.NewCycleState(), req3, pods)
330+
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req3.RequestId, plugins.StateKey(plugin.TypedName().String()))
331331
assert.NoError(t, err)
332332
t.Logf("Long conversation - Hashes %+v, cached servers: %+v", len(state.PrefixHashes), state.PrefixCacheServers)
333333
longHashCount := len(state.PrefixHashes)
@@ -375,7 +375,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
375375
req := &types.LLMRequest{
376376
RequestId: uuid.NewString(),
377377
TargetModel: "model-stress",
378-
Data: &types.LLMRequestData{
378+
Body: &types.LLMRequestBody{
379379
Completions: &types.CompletionsRequest{
380380
Prompt: prompt,
381381
},
@@ -396,7 +396,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
396396
// Second cycle: validate internal state
397397
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req.RequestId, plugins.StateKey(plugin.TypedName().String()))
398398
assert.NoError(b, err)
399-
expectedHashes := int(math.Min(float64(maxPrefixBlocks), float64(len(req.Data.Completions.Prompt)/blockSize)))
399+
expectedHashes := int(math.Min(float64(maxPrefixBlocks), float64(len(req.Body.Completions.Prompt)/blockSize)))
400400
assert.Equal(b, expectedHashes, len(state.PrefixHashes), "number of hashes is incorrect")
401401
}
402402
}
@@ -464,7 +464,7 @@ func BenchmarkPrefixPluginChatCompletionsStress(b *testing.B) {
464464
req := &types.LLMRequest{
465465
RequestId: uuid.NewString(),
466466
TargetModel: "chat-model-stress",
467-
Data: &types.LLMRequestData{
467+
Body: &types.LLMRequestBody{
468468
ChatCompletions: &types.ChatCompletionsRequest{
469469
Messages: messages,
470470
},

pkg/epp/scheduling/types/types.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ type LLMRequest struct {
3232
// TargetModel is the final target model after traffic split.
3333
TargetModel string
3434
// Data contains the request-body fields that we parse out as user input.
35-
Data *LLMRequestData
35+
Body *LLMRequestBody
3636
// Headers is a map of the request headers.
3737
Headers map[string]string
3838
}
@@ -42,14 +42,14 @@ func (r *LLMRequest) String() string {
4242
return nilString
4343
}
4444

45-
return fmt.Sprintf("RequestID: %s, TargetModel: %s, RequestData: %s, Headers: %v",
46-
r.RequestId, r.TargetModel, r.Data, r.Headers)
45+
return fmt.Sprintf("RequestID: %s, TargetModel: %s, Body: %s, Headers: %v",
46+
r.RequestId, r.TargetModel, r.Body, r.Headers)
4747
}
4848

49-
// LLMRequestData contains the request-body fields that we parse out as user input,
49+
// LLMRequestBody contains the request-body fields that we parse out as user input,
5050
// to be used in forming scheduling decisions.
51-
// An LLMRequestData must contain exactly one of CompletionsRequest or ChatCompletionsRequest.
52-
type LLMRequestData struct {
51+
// An LLMRequestBody must contain exactly one of CompletionsRequest or ChatCompletionsRequest.
52+
type LLMRequestBody struct {
5353
// CompletionsRequest is the representation of the OpenAI /v1/completions request body.
5454
Completions *CompletionsRequest `json:"completions,omitempty"`
5555
// ChatCompletionsRequest is the representation of the OpenAI /v1/chat_completions request body.

pkg/epp/util/request/body.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,18 @@ import (
2323
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
2424
)
2525

26-
// ExtractRequestData extracts the LLMRequestData from the given request body map.
27-
func ExtractRequestData(body map[string]any) (*types.LLMRequestData, error) {
26+
// ExtractRequestBody extracts the LLMRequestBody from the given request body map.
27+
func ExtractRequestBody(rawBody map[string]any) (*types.LLMRequestBody, error) {
2828
// Convert map back to JSON bytes
29-
jsonBytes, err := json.Marshal(body)
29+
jsonBytes, err := json.Marshal(rawBody)
3030
if err != nil {
3131
return nil, errutil.Error{Code: errutil.BadRequest, Msg: "invalid request body"}
3232
}
3333

3434
// Try completions request first
3535
var completions types.CompletionsRequest
3636
if err = json.Unmarshal(jsonBytes, &completions); err == nil && completions.Prompt != "" {
37-
return &types.LLMRequestData{Completions: &completions}, nil
37+
return &types.LLMRequestBody{Completions: &completions}, nil
3838
}
3939

4040
// Try chat completions
@@ -47,7 +47,7 @@ func ExtractRequestData(body map[string]any) (*types.LLMRequestData, error) {
4747
return nil, errutil.Error{Code: errutil.BadRequest, Msg: "invalid chat-completions request: " + err.Error()}
4848
}
4949

50-
return &types.LLMRequestData{ChatCompletions: &chatCompletions}, nil
50+
return &types.LLMRequestBody{ChatCompletions: &chatCompletions}, nil
5151
}
5252

5353
func validateChatCompletionsMessages(messages []types.Message) error {

pkg/epp/util/request/body_test.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ func TestExtractRequestData(t *testing.T) {
2727
tests := []struct {
2828
name string
2929
body map[string]any
30-
want *types.LLMRequestData
30+
want *types.LLMRequestBody
3131
wantErr bool
3232
}{
3333
{
@@ -36,7 +36,7 @@ func TestExtractRequestData(t *testing.T) {
3636
"model": "test",
3737
"prompt": "test prompt",
3838
},
39-
want: &types.LLMRequestData{
39+
want: &types.LLMRequestBody{
4040
Completions: &types.CompletionsRequest{
4141
Prompt: "test prompt",
4242
},
@@ -55,7 +55,7 @@ func TestExtractRequestData(t *testing.T) {
5555
},
5656
},
5757
},
58-
want: &types.LLMRequestData{
58+
want: &types.LLMRequestBody{
5959
ChatCompletions: &types.ChatCompletionsRequest{
6060
Messages: []types.Message{
6161
{Role: "system", Content: "this is a system message"},
@@ -79,7 +79,7 @@ func TestExtractRequestData(t *testing.T) {
7979
"add_generation_prompt": true,
8080
"chat_template_kwargs": map[string]any{"key": "value"},
8181
},
82-
want: &types.LLMRequestData{
82+
want: &types.LLMRequestBody{
8383
ChatCompletions: &types.ChatCompletionsRequest{
8484
Messages: []types.Message{{Role: "user", Content: "hello"}},
8585
Tools: []any{map[string]any{"type": "function"}},
@@ -229,17 +229,17 @@ func TestExtractRequestData(t *testing.T) {
229229

230230
for _, tt := range tests {
231231
t.Run(tt.name, func(t *testing.T) {
232-
got, err := ExtractRequestData(tt.body)
232+
got, err := ExtractRequestBody(tt.body)
233233
if (err != nil) != tt.wantErr {
234-
t.Errorf("ExtractRequestData() error = %v, wantErr %v", err, tt.wantErr)
234+
t.Errorf("ExtractRequestBody() error = %v, wantErr %v", err, tt.wantErr)
235235
return
236236
}
237237
if tt.wantErr {
238238
return
239239
}
240240

241241
if diff := cmp.Diff(tt.want, got); diff != "" {
242-
t.Errorf("ExtractRequestData() mismatch (-want +got):\n%s", diff)
242+
t.Errorf("ExtractRequestBody() mismatch (-want +got):\n%s", diff)
243243
}
244244
})
245245
}
@@ -254,7 +254,7 @@ func BenchmarkExtractRequestData_Completions(b *testing.B) {
254254

255255
b.ResetTimer()
256256
for i := 0; i < b.N; i++ {
257-
_, err := ExtractRequestData(body)
257+
_, err := ExtractRequestBody(body)
258258
if err != nil {
259259
b.Fatal(err)
260260
}
@@ -271,7 +271,7 @@ func BenchmarkExtractRequestData_ChatCompletions(b *testing.B) {
271271

272272
b.ResetTimer()
273273
for i := 0; i < b.N; i++ {
274-
_, err := ExtractRequestData(body)
274+
_, err := ExtractRequestBody(body)
275275
if err != nil {
276276
b.Fatal(err)
277277
}
@@ -295,7 +295,7 @@ func BenchmarkExtractRequestData_ChatCompletionsWithOptionals(b *testing.B) {
295295

296296
b.ResetTimer()
297297
for i := 0; i < b.N; i++ {
298-
_, err := ExtractRequestData(body)
298+
_, err := ExtractRequestBody(body)
299299
if err != nil {
300300
b.Fatal(err)
301301
}

0 commit comments

Comments
 (0)