diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go
index 5330dd278..a3e2d6d13 100644
--- a/pkg/epp/requestcontrol/director.go
+++ b/pkg/epp/requestcontrol/director.go
@@ -103,10 +103,11 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
 	}
 	reqCtx.Request.Body["model"] = reqCtx.TargetModelName
 
-	prompt, err := requtil.ExtractPromptFromRequestBody(requestBodyMap)
+	requestBody, err := requtil.ExtractRequestBody(reqCtx.Request.Body)
 	if err != nil {
-		return reqCtx, err
+		return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: fmt.Errorf("failed to extract request data: %w", err).Error()}
 	}
+
 	infObjective := d.datastore.ObjectiveGet(reqCtx.ObjectiveKey)
 	if infObjective == nil {
 		logger.V(logutil.VERBOSE).Info("No associated InferenceObjective found, using default", "objectiveKey", reqCtx.ObjectiveKey)
@@ -124,7 +125,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
 	reqCtx.SchedulingRequest = &schedulingtypes.LLMRequest{
 		RequestId:   reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
 		TargetModel: reqCtx.TargetModelName,
-		Prompt:      prompt,
+		Body:        requestBody,
 		Headers:     reqCtx.Request.Headers,
 	}
 
diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
index 49ec7fa44..40e88062a 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
@@ -125,8 +125,10 @@ func (s *SchedulingContextState) Clone() plugins.StateData {
 }
 
 // compile-time type assertion
-var _ framework.Scorer = &Plugin{}
-var _ requestcontrol.PreRequest = &Plugin{}
+var (
+	_ framework.Scorer          = &Plugin{}
+	_ requestcontrol.PreRequest = &Plugin{}
+)
 
 // PrefixCachePluginFactory defines the factory function for Prefix plugin.
 func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
@@ -248,7 +250,6 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
 			for server := range cachedServers {
 				// Update servers with their longest prefix match.
 				res[server]++
-
 			}
 		}
 	}
@@ -260,33 +261,39 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
 // For block i, hash(i) = hash(block i content, hash(i-1)).
 func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) []BlockHash {
 	loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
-	prompt := []byte(request.Prompt)
-	if len(prompt) < cacheBlockSize {
-		loggerDebug.Info("Request body too small for prefix cache", "size", len(prompt), "block size", cacheBlockSize)
+	if request == nil || request.Body == nil {
+		loggerDebug.Info("Request or request data is nil, skipping hashing")
 		return nil
 	}
-	if len(prompt) > cacheBlockSize*maxPrefixBlocks {
-		loggerDebug.Info("Truncating input", "size", len(prompt), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
-		prompt = prompt[:maxPrefixBlocks*cacheBlockSize]
+
+	userInput, err := getUserInputBytes(request)
+	if err != nil {
+		loggerDebug.Error(err, "Failed to get user input bytes")
+		return nil
 	}
-	// Split the body into blocks of size cacheBlockSize. The +1 is to account for the model.
-	// If the last block is smaller than cacheBlockSize, it will be ignored.
-	res := make([]BlockHash, 0, 1+len(prompt)/cacheBlockSize)
-	// Add the model to the first block hash so that different models have different hashes even with the same body.
-	firstBlockSize := cacheBlockSize
-	if len(prompt) < cacheBlockSize {
-		firstBlockSize = len(prompt)
+	if len(userInput) < cacheBlockSize {
+		loggerDebug.Info("Request body too small for prefix cache", "size", len(userInput), "block size", cacheBlockSize)
+		return nil
 	}
-	firstBlock := prompt[0:firstBlockSize]
-	firstBlockWithModel := append([]byte(request.TargetModel), firstBlock...)
-	res = append(res, BlockHash(xxhash.Sum64(firstBlockWithModel)))
-
-	for i := cacheBlockSize; i+cacheBlockSize <= len(prompt); i += cacheBlockSize {
-		block := prompt[i : i+cacheBlockSize]
-		prevBlockHash := res[len(res)-1]
-		block = append(block, toBytes(prevBlockHash)...)
-		res = append(res, BlockHash(xxhash.Sum64(block)))
+	if len(userInput) > cacheBlockSize*maxPrefixBlocks {
+		loggerDebug.Info("Truncating input", "size", len(userInput), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
+		userInput = userInput[:maxPrefixBlocks*cacheBlockSize]
+	}
+	// Split the body into blocks of size cacheBlockSize.
+	// If the last block is smaller than cacheBlockSize, it will be ignored.
+	res := make([]BlockHash, 0, len(userInput)/cacheBlockSize)
+	// Add the model to the first block hash so that different models have different hashes even with the same body.
+	h := xxhash.New()
+	_, _ = h.Write([]byte(request.TargetModel))
+	prevBlockHash := BlockHash(h.Sum64())
+	for i := 0; i+cacheBlockSize <= len(userInput); i += cacheBlockSize {
+		h.Reset()
+		_, _ = h.Write(userInput[i : i+cacheBlockSize])
+		_, _ = h.Write(toBytes(prevBlockHash))
+		res = append(res, BlockHash(h.Sum64()))
+
+		prevBlockHash = res[len(res)-1]
 	}
 	return res
 }
@@ -296,3 +303,12 @@ func toBytes(i BlockHash) []byte {
 	binary.LittleEndian.PutUint64(bytes, uint64(i))
 	return bytes
 }
+
+func getUserInputBytes(request *types.LLMRequest) ([]byte, error) {
+	if request.Body.Completions != nil { // assumed to be valid if not nil
+		return []byte(request.Body.Completions.Prompt), nil
+	}
+
+	// must be chat-completions request at this point, return bytes of entire messages
+	return json.Marshal(request.Body.ChatCompletions.Messages)
+}
diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
index 3fbac2ce1..9f9893ba8 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
@@ -33,8 +33,7 @@ import (
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
 )
 
-func TestPrefixPlugin(t *testing.T) {
-
+func TestPrefixPluginCompletion(t *testing.T) {
 	config := Config{
 		HashBlockSize:          4,
 		MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
 		LRUCapacityPerServer:   DefaultLRUCapacityPerServer,
 	}
 	plugin := New(context.Background(), config)
@@ -50,7 +49,11 @@ func TestPrefixPlugin(t *testing.T) {
 	req1 := &types.LLMRequest{
 		RequestId:   uuid.NewString(),
 		TargetModel: "test-model1",
-		Prompt:      "aaaaaa",
+		Body: &types.LLMRequestBody{
+			Completions: &types.CompletionsRequest{
+				Prompt: "aaaaaa",
+			},
+		},
 	}
 	scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods)
 	state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String()))
@@ -78,7 +81,11 @@ func TestPrefixPlugin(t *testing.T) {
 	req2 := &types.LLMRequest{
 		RequestId:   uuid.NewString(),
 		TargetModel: "test-model2",
-		Prompt:      "bbbbbb",
+		Body: &types.LLMRequestBody{
+			Completions: &types.CompletionsRequest{
+				Prompt: "bbbbbb",
+			},
+		},
 	}
 	scores = plugin.Score(context.Background(), types.NewCycleState(), req2, pods)
 	state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, plugins.StateKey(plugin.TypedName().String()))
@@ -105,7 +112,11 @@ func TestPrefixPlugin(t *testing.T) {
 	req3 := &types.LLMRequest{
 		RequestId:   uuid.NewString(),
 		TargetModel: "test-model1",
-		Prompt:      "aaaabbbb",
+		Body: &types.LLMRequestBody{
+			Completions: &types.CompletionsRequest{
+				Prompt: "aaaabbbb",
+			},
+		},
 	}
 	scores = plugin.Score(context.Background(), types.NewCycleState(), req3, pods)
 	state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req3.RequestId, plugins.StateKey(plugin.TypedName().String()))
@@ -131,7 +142,11 @@ func TestPrefixPlugin(t *testing.T) {
 	req4 := &types.LLMRequest{
 		RequestId:   uuid.NewString(),
 		TargetModel: "test-model-new",
-		Prompt:      "aaaabbbb",
+		Body: &types.LLMRequestBody{
+			Completions: &types.CompletionsRequest{
+				Prompt: "aaaabbbb",
+			},
+		},
 	}
 	scores = plugin.Score(context.Background(), types.NewCycleState(), req4, pods)
 	state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req4.RequestId, plugins.StateKey(plugin.TypedName().String()))
@@ -157,7 +172,11 @@ func TestPrefixPlugin(t *testing.T) {
 	req5 := &types.LLMRequest{
 		RequestId:   uuid.NewString(),
 		TargetModel: "test-model1",
-		Prompt:      "aaaabbbbcccc",
+		Body: &types.LLMRequestBody{
+			Completions: &types.CompletionsRequest{
+				Prompt: "aaaabbbbcccc",
+			},
+		},
 	}
 	scores = plugin.Score(context.Background(), types.NewCycleState(), req5, pods)
 	state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req5.RequestId, plugins.StateKey(plugin.TypedName().String()))
@@ -180,6 +199,149 @@ func TestPrefixPlugin(t *testing.T) {
 	plugin.wg.Wait()
 }
 
+func TestPrefixPluginChatCompletions(t *testing.T) {
+	config := Config{
+		HashBlockSize:          4,
+		MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
+		LRUCapacityPerServer:   DefaultLRUCapacityPerServer,
+	}
+	plugin := New(context.Background(), config)
+
+	pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}
+	pods := []types.Pod{pod1}
+
+	// Test with chat completions request
+	req1 := &types.LLMRequest{
+		RequestId:   uuid.NewString(),
+		TargetModel: "test-model1",
+		Body: &types.LLMRequestBody{
+			ChatCompletions: &types.ChatCompletionsRequest{
+				Messages: []types.Message{
+					{Role: "user", Content: "hello world"},
+					{Role: "assistant", Content: "hi there"},
+				},
+			},
+		},
+	}
+	scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods)
+	state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String()))
+	assert.NoError(t, err)
+	t.Logf("Chat completions - Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
+	// Should have some hashes for the JSON-encoded messages
+	assert.Greater(t, len(state.PrefixHashes), 1, "should have hashes for chat completions")
+	assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers initially")
+	assert.Equal(t, float64(0), scores[pod1], "score for pod1")
+}
+
+func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
+	config := Config{
+		HashBlockSize:          8, // Use larger block size for more predictable JSON marshaling
+		MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
+		LRUCapacityPerServer:   DefaultLRUCapacityPerServer,
+	}
+	plugin := New(context.Background(), config)
+
+	pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}
+	pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}
+	pods := []types.Pod{pod1, pod2}
+
+	// First request with initial conversation
+	req1 := &types.LLMRequest{
+		RequestId:   uuid.NewString(),
+		TargetModel: "test-model1",
+		Body: &types.LLMRequestBody{
+			ChatCompletions: &types.ChatCompletionsRequest{
+				Messages: []types.Message{
+					{Role: "system", Content: "You are a helpful assistant"},
+					{Role: "user", Content: "Hello, how are you?"},
+				},
+			},
+		},
+	}
+	scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods)
+	state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String()))
+	assert.NoError(t, err)
+	t.Logf("Initial conversation - Hashes %+v, cached servers: %+v", len(state.PrefixHashes), state.PrefixCacheServers)
+	initialHashCount := len(state.PrefixHashes)
+	assert.Greater(t, initialHashCount, 1, "should have hashes for chat completions")
+	assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers initially")
+	assert.Equal(t, float64(0), scores[pod1], "score for pod1")
+	assert.Equal(t, float64(0), scores[pod2], "score for pod2")
+
+	// Simulate pod1 was picked
+	schedulingResult := &types.SchedulingResult{
+		PrimaryProfileName: "default",
+		ProfileResults: map[string]*types.ProfileRunResult{
+			"default": {TargetPods: []types.Pod{pod1}},
+		},
+	}
+	plugin.PreRequest(context.Background(), req1, schedulingResult, 0)
+
+	// Second request adds assistant response and new user message (conversation grows)
+	req2 := &types.LLMRequest{
+		RequestId:   uuid.NewString(),
+		TargetModel: "test-model1",
+		Body: &types.LLMRequestBody{
+			ChatCompletions: &types.ChatCompletionsRequest{
+				Messages: []types.Message{
+					{Role: "system", Content: "You are a helpful assistant"},
+					{Role: "user", Content: "Hello, how are you?"},
+					{Role: "assistant", Content: "I'm doing well, thank you! How can I help you today?"},
+					{Role: "user", Content: "Can you explain how prefix caching works?"},
+				},
+			},
+		},
+	}
+	scores = plugin.Score(context.Background(), types.NewCycleState(), req2, pods)
+	state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, plugins.StateKey(plugin.TypedName().String()))
+	assert.NoError(t, err)
+	t.Logf("Extended conversation - Hashes %+v, cached servers: %+v", len(state.PrefixHashes), state.PrefixCacheServers)
+	extendedHashCount := len(state.PrefixHashes)
+	assert.Greater(t, extendedHashCount, initialHashCount, "extended conversation should have more hashes")
+	assert.Greater(t, len(state.PrefixCacheServers), 0, "should have cached servers from prefix match")
+
+	// Calculate expected score - pod1 should have cached the initial prefix
+	cachedBlocks := state.PrefixCacheServers[ServerID(pod1.GetPod().NamespacedName)]
+	expectedScore := float64(cachedBlocks) / float64(extendedHashCount)
+	assert.Equal(t, expectedScore, scores[pod1], "pod1 should have prefix cache hit")
+	assert.Equal(t, float64(0), scores[pod2], "pod2 should have no cache hit")
+
+	// Simulate pod1 was picked again
+	plugin.PreRequest(context.Background(), req2, schedulingResult, 0)
+
+	// Third request continues the conversation even further
+	req3 := &types.LLMRequest{
+		RequestId:   uuid.NewString(),
+		TargetModel: "test-model1",
+		Body: &types.LLMRequestBody{
+			ChatCompletions: &types.ChatCompletionsRequest{
+				Messages: []types.Message{
+					{Role: "system", Content: "You are a helpful assistant"},
+					{Role: "user", Content: "Hello, how are you?"},
+					{Role: "assistant", Content: "I'm doing well, thank you! How can I help you today?"},
+					{Role: "user", Content: "Can you explain how prefix caching works?"},
+					{Role: "assistant", Content: "Prefix caching is a technique where..."},
+					{Role: "user", Content: "That's very helpful, thank you!"},
+				},
+			},
+		},
+	}
+	scores = plugin.Score(context.Background(), types.NewCycleState(), req3, pods)
+	state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req3.RequestId, plugins.StateKey(plugin.TypedName().String()))
+	assert.NoError(t, err)
+	t.Logf("Long conversation - Hashes %+v, cached servers: %+v", len(state.PrefixHashes), state.PrefixCacheServers)
+	longHashCount := len(state.PrefixHashes)
+	assert.Greater(t, longHashCount, extendedHashCount, "long conversation should have even more hashes")
+	assert.Greater(t, len(state.PrefixCacheServers), 0, "should have cached servers from prefix match")
+
+	// pod1 should have an even higher cache hit rate now
+	cachedBlocks = state.PrefixCacheServers[ServerID(pod1.GetPod().NamespacedName)]
+	expectedScore = float64(cachedBlocks) / float64(longHashCount)
+	assert.Equal(t, expectedScore, scores[pod1], "pod1 should have higher prefix cache hit")
+	assert.Greater(t, scores[pod1], float64(0.5), "cache hit rate should be substantial for growing conversation")
+	assert.Equal(t, float64(0), scores[pod2], "pod2 should still have no cache hit")
+}
+
 // TestPrefixPluginStress is a stress test for the prefix scoring plugin, using prompts of increasing length.
 func BenchmarkPrefixPluginStress(b *testing.B) {
 	blockSize := 4
@@ -213,7 +375,11 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
 	req := &types.LLMRequest{
 		RequestId:   uuid.NewString(),
 		TargetModel: "model-stress",
-		Prompt:      prompt,
+		Body: &types.LLMRequestBody{
+			Completions: &types.CompletionsRequest{
+				Prompt: prompt,
+			},
+		},
 	}
 
 	// First cycle: simulate scheduling and insert prefix info into the cache
@@ -230,7 +396,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
 	// Second cycle: validate internal state
 	state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req.RequestId, plugins.StateKey(plugin.TypedName().String()))
 	assert.NoError(b, err)
-	expectedHashes := int(math.Min(float64(maxPrefixBlocks), float64(len(req.Prompt)/blockSize)))
+	expectedHashes := int(math.Min(float64(maxPrefixBlocks), float64(len(req.Body.Completions.Prompt)/blockSize)))
 	assert.Equal(b, expectedHashes, len(state.PrefixHashes), "number of hashes is incorrect")
 	}
 }
@@ -244,3 +410,76 @@ func randomPrompt(n int) string {
 	}
 	return sb.String()
 }
+
+// BenchmarkPrefixPluginChatCompletionsStress is a stress test for chat completions with varying message counts and lengths
+func BenchmarkPrefixPluginChatCompletionsStress(b *testing.B) {
+	blockSize := 8
+	maxPrefixBlocks := 50000
+	config := Config{
+		HashBlockSize:          blockSize,
+		MaxPrefixBlocksToMatch: maxPrefixBlocks,
+		LRUCapacityPerServer:   DefaultLRUCapacityPerServer,
+	}
+
+	plugin := New(context.Background(), config)
+
+	// Test scenarios: varying number of messages and message lengths
+	scenarios := []struct {
+		messageCount  int
+		messageLength int
+	}{
+		{2, 50},   // Short conversation, short messages
+		{2, 500},  // Short conversation, long messages
+		{5, 100},  // Medium conversation, medium messages
+		{10, 200}, // Long conversation, medium messages
+		{20, 100}, // Very long conversation, medium messages
+		{50, 50},  // Very long conversation, short messages
+		{100, 25}, // Extremely long conversation, very short messages
+	}
+
+	for _, scenario := range scenarios {
+		b.Run(fmt.Sprintf("messages_%d_length_%d", scenario.messageCount, scenario.messageLength), func(b *testing.B) {
+			// Generate messages for this scenario
+			messages := make([]types.Message, scenario.messageCount)
+			messages[0] = types.Message{Role: "system", Content: "You are a helpful assistant."}
+
+			for i := 1; i < scenario.messageCount; i++ {
+				role := "user"
+				if i%2 == 0 {
+					role = "assistant"
+				}
+				content := randomPrompt(scenario.messageLength)
+				messages[i] = types.Message{Role: role, Content: content}
+			}
+
+			pod := &types.PodMetrics{
+				Pod: &backend.Pod{
+					NamespacedName: k8stypes.NamespacedName{
+						Name: fmt.Sprintf("chat-pod-%d-%d", scenario.messageCount, scenario.messageLength),
+					},
+				},
+			}
+			pods := []types.Pod{pod}
+
+			req := &types.LLMRequest{
+				RequestId:   uuid.NewString(),
+				TargetModel: "chat-model-stress",
+				Body: &types.LLMRequestBody{
+					ChatCompletions: &types.ChatCompletionsRequest{
+						Messages: messages,
+					},
+				},
+			}
+
+			b.ResetTimer()
+			for i := 0; i < b.N; i++ {
+				// Benchmark the scoring operation
+				scores := plugin.Score(context.Background(), nil, req, pods)
+				_ = scores // Use the result to prevent optimization
+
+				// Clean up state for next iteration
+				plugin.pluginState.Delete(req.RequestId)
+			}
+		})
+	}
+}
diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go
index 296211759..2685a22d0 100644
--- a/pkg/epp/scheduling/types/types.go
+++ b/pkg/epp/scheduling/types/types.go
@@ -23,20 +23,90 @@ import (
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" ) +const nilString = "" + // LLMRequest is a structured representation of the fields we parse out of the LLMRequest body. type LLMRequest struct { // RequestId is the Envoy generated Id for the request being processed RequestId string // TargetModel is the final target model after traffic split. TargetModel string - // Prompt is the prompt that was sent in the request body. - Prompt string + // Data contains the request-body fields that we parse out as user input. + Body *LLMRequestBody // Headers is a map of the request headers. Headers map[string]string } func (r *LLMRequest) String() string { - return fmt.Sprintf("RequestID: %s, TargetModel: %s, PromptLength: %d, Headers: %v", r.RequestId, r.TargetModel, len(r.Prompt), r.Headers) + if r == nil { + return nilString + } + + return fmt.Sprintf("RequestID: %s, TargetModel: %s, Body: %s, Headers: %v", + r.RequestId, r.TargetModel, r.Body, r.Headers) +} + +// LLMRequestBody contains the request-body fields that we parse out as user input, +// to be used in forming scheduling decisions. +// An LLMRequestBody must contain exactly one of CompletionsRequest or ChatCompletionsRequest. +type LLMRequestBody struct { + // CompletionsRequest is the representation of the OpenAI /v1/completions request body. + Completions *CompletionsRequest `json:"completions,omitempty"` + // ChatCompletionsRequest is the representation of the OpenAI /v1/chat_completions request body. + ChatCompletions *ChatCompletionsRequest `json:"chat_completions,omitempty"` +} + +// CompletionsRequest is a structured representation of the fields we parse out of the +// /v1/completions request body. +// This struct includes fields usable for plugins and scheduling decisions - and not the entire +// API spec. +type CompletionsRequest struct { + // Prompt is the prompt that was sent in the request body. + Prompt string `json:"prompt,omitempty"` +} + +func (r *CompletionsRequest) String() string { + if r == nil { + return nilString + } + + return fmt.Sprintf("{PromptLength: %d}", len(r.Prompt)) +} + +// ChatCompletionsRequest is a structured representation of the fields we parse out of the +// /v1/chat/completions request body. +// This struct includes fields usable for plugins and scheduling decisions - and not the entire +// API spec. +type ChatCompletionsRequest struct { + /* parameters from the official OpenAI chat-completions API */ + Messages []Message `json:"messages,omitempty"` + Tools []interface{} `json:"tools,omitempty"` + /* parameters from the HuggingFace transformers chat-templates API */ + Documents []interface{} `json:"documents,omitempty"` + ChatTemplate string `json:"chat_template,omitempty"` + ReturnAssistantTokensMask bool `json:"return_assistant_tokens_mask,omitempty"` + ContinueFinalMessage bool `json:"continue_final_message,omitempty"` + AddGenerationPrompt bool `json:"add_generation_prompt,omitempty"` + ChatTemplateKWArgs map[string]interface{} `json:"chat_template_kwargs,omitempty"` +} + +func (r *ChatCompletionsRequest) String() string { + if r == nil { + return nilString + } + + messagesLen := 0 + for _, msg := range r.Messages { + messagesLen += len(msg.Content) + } + + return fmt.Sprintf("{MessagesLength: %d}", messagesLen) +} + +// Message represents a single message in a chat-completions request. 
+type Message struct {
+	Role    string
+	Content string // TODO: support multi-modal content
 }
 
 type Pod interface {
@@ -52,8 +122,9 @@ type ScoredPod struct {
 
 func (pm *PodMetrics) String() string {
 	if pm == nil {
-		return ""
+		return nilString
 	}
+
 	return fmt.Sprintf("%+v", *pm)
 }
 
diff --git a/pkg/epp/util/request/body.go b/pkg/epp/util/request/body.go
index 46de1fa54..07877415f 100644
--- a/pkg/epp/util/request/body.go
+++ b/pkg/epp/util/request/body.go
@@ -17,70 +17,43 @@ limitations under the License.
 
 package request
 
 import (
-	"fmt"
+	"encoding/json"
+
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
 	errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
 )
 
-func ExtractPromptFromRequestBody(body map[string]any) (string, error) {
-	if _, ok := body["messages"]; ok {
-		return extractPromptFromMessagesField(body)
+// ExtractRequestBody extracts the LLMRequestBody from the given request body map.
+func ExtractRequestBody(rawBody map[string]any) (*types.LLMRequestBody, error) {
+	// Convert map back to JSON bytes
+	jsonBytes, err := json.Marshal(rawBody)
+	if err != nil {
+		return nil, errutil.Error{Code: errutil.BadRequest, Msg: "invalid request body"}
 	}
-	return extractPromptField(body)
-}
 
-func extractPromptField(body map[string]any) (string, error) {
-	prompt, ok := body["prompt"]
-	if !ok {
-		return "", errutil.Error{Code: errutil.BadRequest, Msg: "prompt not found in request"}
-	}
-	promptStr, ok := prompt.(string)
-	if !ok {
-		return "", errutil.Error{Code: errutil.BadRequest, Msg: "prompt is not a string"}
+	// Try completions request first
+	var completions types.CompletionsRequest
+	if err = json.Unmarshal(jsonBytes, &completions); err == nil && completions.Prompt != "" {
+		return &types.LLMRequestBody{Completions: &completions}, nil
 	}
-	return promptStr, nil
-}
 
-func extractPromptFromMessagesField(body map[string]any) (string, error) {
-	messages, ok := body["messages"]
-	if !ok {
-		return "", errutil.Error{Code: errutil.BadRequest, Msg: "messages not found in request"}
-	}
-	messageList, ok := messages.([]any)
-	if !ok {
-		return "", errutil.Error{Code: errutil.BadRequest, Msg: "messages is not a list"}
-	}
-	if len(messageList) == 0 {
-		return "", errutil.Error{Code: errutil.BadRequest, Msg: "messages is empty"}
+	// Try chat completions
+	var chatCompletions types.ChatCompletionsRequest
+	if err = json.Unmarshal(jsonBytes, &chatCompletions); err != nil {
+		return nil, errutil.Error{Code: errutil.BadRequest, Msg: "invalid request format"}
 	}
-	prompt := ""
-	for _, msg := range messageList {
-		msgMap, ok := msg.(map[string]any)
-		if !ok {
-			continue
-		}
-		content, ok := msgMap["content"]
-		if !ok {
-			continue
-		}
-		contentStr, ok := content.(string)
-		if !ok {
-			continue
-		}
-		role, ok := msgMap["role"]
-		if !ok {
-			continue
-		}
-		roleStr, ok := role.(string)
-		if !ok {
-			continue
-		}
-		prompt += constructChatMessage(roleStr, contentStr)
+	if err = validateChatCompletionsMessages(chatCompletions.Messages); err != nil {
+		return nil, errutil.Error{Code: errutil.BadRequest, Msg: "invalid chat-completions request: " + err.Error()}
 	}
-	return prompt, nil
+
+	return &types.LLMRequestBody{ChatCompletions: &chatCompletions}, nil
 }
 
-func constructChatMessage(role string, content string) string {
-	return fmt.Sprintf("<|im_start|>%s\n%s<|im_end|>\n", role, content)
+func validateChatCompletionsMessages(messages []types.Message) error {
+	if len(messages) == 0 {
+		return errutil.Error{Code: errutil.BadRequest, Msg: "chat-completions request must have at least one message"}
+	}
+
+	return nil
 }
diff --git a/pkg/epp/util/request/body_test.go b/pkg/epp/util/request/body_test.go
index ce5a93921..64ab6de11 100644
--- a/pkg/epp/util/request/body_test.go
+++ b/pkg/epp/util/request/body_test.go
@@ -18,16 +18,30 @@ package request
 
 import (
 	"testing"
+
+	"github.com/google/go-cmp/cmp"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
 )
 
-func TestExtractPromptFromRequestBody(t *testing.T) {
+func TestExtractRequestData(t *testing.T) {
 	tests := []struct {
 		name    string
 		body    map[string]any
-		want    string
+		want    *types.LLMRequestBody
 		wantErr bool
-		errType error
 	}{
+		{
+			name: "completions request body",
+			body: map[string]any{
+				"model":  "test",
+				"prompt": "test prompt",
+			},
+			want: &types.LLMRequestBody{
+				Completions: &types.CompletionsRequest{
+					Prompt: "test prompt",
+				},
+			},
+		},
 		{
 			name: "chat completions request body",
 			body: map[string]any{
@@ -39,137 +53,175 @@ func TestExtractPromptFromRequestBody(t *testing.T) {
 					map[string]any{
 						"role": "user", "content": "hello",
 					},
-					map[string]any{
-						"role": "assistant", "content": "hi, what can I do for you?",
+				},
+			},
+			want: &types.LLMRequestBody{
+				ChatCompletions: &types.ChatCompletionsRequest{
+					Messages: []types.Message{
+						{Role: "system", Content: "this is a system message"},
+						{Role: "user", Content: "hello"},
 					},
 				},
 			},
-			want: "<|im_start|>system\nthis is a system message<|im_end|>\n" +
-				"<|im_start|>user\nhello<|im_end|>\n" +
-				"<|im_start|>assistant\nhi, what can I do for you?<|im_end|>\n",
 		},
 		{
-			name: "completions request body",
+			name: "chat completions with all optional fields",
 			body: map[string]any{
-				"model":  "test",
-				"prompt": "test prompt",
+				"model": "test",
+				"messages": []any{
+					map[string]any{"role": "user", "content": "hello"},
+				},
+				"tools":                        []any{map[string]any{"type": "function"}},
+				"documents":                    []any{map[string]any{"content": "doc"}},
+				"chat_template":                "custom template",
+				"return_assistant_tokens_mask": true,
+				"continue_final_message":       true,
+				"add_generation_prompt":        true,
+				"chat_template_kwargs":         map[string]any{"key": "value"},
+			},
+			want: &types.LLMRequestBody{
+				ChatCompletions: &types.ChatCompletionsRequest{
+					Messages:                  []types.Message{{Role: "user", Content: "hello"}},
+					Tools:                     []any{map[string]any{"type": "function"}},
+					Documents:                 []any{map[string]any{"content": "doc"}},
+					ChatTemplate:              "custom template",
+					ReturnAssistantTokensMask: true,
+					ContinueFinalMessage:      true,
+					AddGenerationPrompt:       true,
+					ChatTemplateKWArgs:        map[string]any{"key": "value"},
+				},
 			},
-			want: "test prompt",
+		},
+		{
+			name:    "nil body",
+			body:    nil,
+			wantErr: true,
 		},
 		{
 			name: "invalid prompt format",
+			body: map[string]any{
+				"model":  "test",
+				"prompt": 123,
+			},
+			wantErr: true,
+		},
+		{
+			name: "invalid messages format",
+			body: map[string]any{
+				"model":    "test",
+				"messages": "invalid",
+			},
+			wantErr: true,
+		},
+		{
+			name: "neither prompt nor messages",
 			body: map[string]any{
 				"model": "test",
-				"prompt": []any{
-					map[string]any{
-						"role": "system", "content": "this is a system message",
-					},
-					map[string]any{
-						"role": "user", "content": "hello",
-					},
-					map[string]any{
-						"role": "assistant", "content": "hi, what can I",
-					},
-				},
+			},
+			wantErr: true,
+		},
+		{
+			name: "empty messages array",
+			body: map[string]any{
+				"model":    "test",
+				"messages": []any{},
+			},
+			wantErr: true,
+		},
+		{
+			name: "message with non-string role",
+			body: map[string]any{
+				"model": "test",
+				"messages": []any{
+					map[string]any{"role": 123, "content": "hello"},
+				},
 			},
 			wantErr: true,
 		},
{ - name: "invalid messaged format", + name: "message with non-string content", body: map[string]any{ "model": "test", - "messages": map[string]any{ - "role": "system", "content": "this is a system message", + "messages": []any{ + map[string]any{"role": "user", "content": 123}, }, }, wantErr: true, }, { - name: "prompt does not exist", + name: "invalid tools format", body: map[string]any{ "model": "test", + "messages": []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + "tools": "invalid", }, wantErr: true, }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := ExtractPromptFromRequestBody(tt.body) - if (err != nil) != tt.wantErr { - t.Errorf("ExtractPromptFromRequestBody() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("ExtractPromptFromRequestBody() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestExtractPromptField(t *testing.T) { - tests := []struct { - name string - body map[string]any - want string - wantErr bool - }{ { - name: "valid prompt", + name: "invalid documents format", body: map[string]any{ - "prompt": "test prompt", + "model": "test", + "messages": []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + "documents": "invalid", }, - want: "test prompt", + wantErr: true, }, { - name: "prompt not found", - body: map[string]any{}, + name: "invalid chat_template format", + body: map[string]any{ + "model": "test", + "messages": []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + "chat_template": 123, + }, wantErr: true, }, { - name: "non-string prompt", + name: "invalid return_assistant_tokens_mask format", body: map[string]any{ - "prompt": 123, + "model": "test", + "messages": []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + "return_assistant_tokens_mask": "invalid", }, wantErr: true, }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := extractPromptField(tt.body) - if (err != nil) != tt.wantErr { - t.Errorf("extractPromptField() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("extractPromptField() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestExtractPromptFromMessagesField(t *testing.T) { - tests := []struct { - name string - body map[string]any - want string - wantErr bool - }{ { - name: "valid messages", + name: "invalid continue_final_message format", body: map[string]any{ + "model": "test", "messages": []any{ - map[string]any{"role": "user", "content": "test1"}, - map[string]any{"role": "assistant", "content": "test2"}, + map[string]any{"role": "user", "content": "hello"}, }, + "continue_final_message": "invalid", }, - want: "<|im_start|>user\ntest1<|im_end|>\n<|im_start|>assistant\ntest2<|im_end|>\n", + wantErr: true, }, { - name: "invalid messages format", + name: "invalid add_generation_prompt format", body: map[string]any{ - "messages": "invalid", + "model": "test", + "messages": []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + "add_generation_prompt": "invalid", + }, + wantErr: true, + }, + { + name: "invalid chat_template_kwargs format", + body: map[string]any{ + "model": "test", + "messages": []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + "chat_template_kwargs": "invalid", }, wantErr: true, }, @@ -177,31 +229,75 @@ func TestExtractPromptFromMessagesField(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := extractPromptFromMessagesField(tt.body) + 
+			got, err := ExtractRequestBody(tt.body)
 			if (err != nil) != tt.wantErr {
-				t.Errorf("extractPromptFromMessagesField() error = %v, wantErr %v", err, tt.wantErr)
+				t.Errorf("ExtractRequestBody() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+			if tt.wantErr {
 				return
 			}
-			if got != tt.want {
-				t.Errorf("extractPromptFromMessagesField() got = %v, want %v", got, tt.want)
+
+			if diff := cmp.Diff(tt.want, got); diff != "" {
+				t.Errorf("ExtractRequestBody() mismatch (-want +got):\n%s", diff)
 			}
 		})
 	}
 }
 
-func TestConstructChatMessage(t *testing.T) {
-	tests := []struct {
-		role    string
-		content string
-		want    string
-	}{
-		{"user", "hello", "<|im_start|>user\nhello<|im_end|>\n"},
-		{"assistant", "hi", "<|im_start|>assistant\nhi<|im_end|>\n"},
+// Benchmark tests for performance comparison
+func BenchmarkExtractRequestData_Completions(b *testing.B) {
+	body := map[string]any{
+		"model":  "test",
+		"prompt": "test prompt",
 	}
-	for _, tt := range tests {
-		if got := constructChatMessage(tt.role, tt.content); got != tt.want {
-			t.Errorf("constructChatMessage() = %v, want %v", got, tt.want)
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, err := ExtractRequestBody(body)
+		if err != nil {
+			b.Fatal(err)
+		}
+	}
+}
+
+func BenchmarkExtractRequestData_ChatCompletions(b *testing.B) {
+	body := map[string]any{
+		"model": "test",
+		"messages": []any{
+			map[string]any{"role": "user", "content": "hello"},
+		},
+	}
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, err := ExtractRequestBody(body)
+		if err != nil {
+			b.Fatal(err)
+		}
+	}
+}
+
+func BenchmarkExtractRequestData_ChatCompletionsWithOptionals(b *testing.B) {
+	body := map[string]any{
+		"model": "test",
+		"messages": []any{
+			map[string]any{"role": "user", "content": "hello"},
+		},
+		"tools":                        []any{map[string]any{"type": "function"}},
+		"documents":                    []any{map[string]any{"content": "doc"}},
+		"chat_template":                "custom template",
+		"return_assistant_tokens_mask": true,
+		"continue_final_message":       true,
+		"add_generation_prompt":        true,
+		"chat_template_kwargs":         map[string]any{"key": "value"},
+	}
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, err := ExtractRequestBody(body)
+		if err != nil {
+			b.Fatal(err)
 		}
 	}
 }