diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go
index 001fdc344..27be31c31 100644
--- a/pkg/epp/handlers/server.go
+++ b/pkg/epp/handlers/server.go
@@ -115,6 +115,7 @@ type Request struct {
 }
 
 type Response struct {
 	Headers map[string]string
+	Body    []byte
 }
 
 type StreamRequestState int
@@ -302,6 +303,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
 					break
 				}
 
+				reqCtx.Response.Body = body
 				reqCtx, responseErr = s.HandleResponseBody(ctx, reqCtx, responseBody)
 				if responseErr != nil {
 					if logger.V(logutil.DEBUG).Enabled() {
diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go
index d7de39d4a..b91053d0f 100644
--- a/pkg/epp/requestcontrol/director.go
+++ b/pkg/epp/requestcontrol/director.go
@@ -289,14 +289,15 @@ func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *hand
 
 // HandleResponseBodyComplete is called when the response body is fully received.
 func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
-	logger := log.FromContext(ctx).WithValues("stage", "bodyChunk")
+	requestID := reqCtx.Request.Headers[requtil.RequestIdHeaderKey]
+	logger := log.FromContext(ctx).WithValues("stage", "bodyComplete", requtil.RequestIdHeaderKey, requestID)
 	logger.V(logutil.DEBUG).Info("Entering HandleResponseBodyComplete")
 
-	response := &Response{
-		RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
-		Headers:   reqCtx.Response.Headers,
+	llmResponse, err := schedulingtypes.NewLLMResponseFromBytes(reqCtx.Response.Body)
+	if err != nil {
+		logger.Error(err, "HandleResponseBodyComplete: failed to convert the response body to an LLMResponse")
+		return reqCtx, err
 	}
-
-	d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
+	d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, llmResponse, reqCtx.TargetPod)
 
 	logger.V(logutil.DEBUG).Info("Exiting HandleResponseBodyComplete")
 	return reqCtx, nil
@@ -346,7 +347,7 @@ func (d *Director) runResponseStreamingPlugins(ctx context.Context, request *sch
 	}
 }
 
-func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
+func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *schedulingtypes.LLMResponse, targetPod *backend.Pod) {
 	loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
 	for _, plugin := range d.requestControlPlugins.responseCompletePlugins {
 		loggerDebug.Info("Running ResponseComplete plugin", "plugin", plugin.TypedName())
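Note for reviewers: `server.go` now buffers the full response body on the request context, and the director defers all interpretation of that body to `NewLLMResponseFromBytes` (added below in `pkg/epp/scheduling/types/llmresponse.go`). A minimal standalone sketch of what a `ResponseComplete` consumer now receives — the JSON literal is illustrative, not taken from the test suite:

```go
package main

import (
	"fmt"

	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

func main() {
	// Hypothetical buffered body; any OpenAI-style chat completion JSON works.
	body := []byte(`{"choices":[{"message":{"role":"assistant","content":"Hi"},"finish_reason":"stop"}]}`)
	resp, err := schedulingtypes.NewLLMResponseFromBytes(body)
	if err != nil {
		panic(err)
	}
	// FirstChoiceContent returns the bytes the prefix plugin hashes on completion.
	content, _ := resp.FirstChoiceContent()
	fmt.Printf("first choice: %s\n", content)
}
```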
diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go
index ffd62da36..896f14aa2 100644
--- a/pkg/epp/requestcontrol/director_test.go
+++ b/pkg/epp/requestcontrol/director_test.go
@@ -661,6 +661,27 @@ func TestDirector_HandleResponseComplete(t *testing.T) {
 	mockSched := &mockScheduler{}
 	director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithResponseCompletePlugins(pc1))
 
+	chatCompletionJSON := `{
+		"choices": [
+			{
+				"message": {
+					"role": "assistant",
+					"content": "Hello!"
+				},
+				"finish_reason": "stop"
+			}
+		],
+		"usage": {
+			"prompt_tokens": 1,
+			"completion_tokens": 2,
+			"total_tokens": 3
+		}
+	}`
+	wantLLMResponse, err := schedulingtypes.NewLLMResponseFromBytes([]byte(chatCompletionJSON))
+	if err != nil {
+		t.Fatalf("NewLLMResponseFromBytes failed with error: %v", err)
+	}
+
 	reqCtx := &handlers.RequestContext{
 		Request: &handlers.Request{
 			Headers: map[string]string{
@@ -669,24 +690,22 @@ func TestDirector_HandleResponseComplete(t *testing.T) {
 		},
 		Response: &handlers.Response{
 			Headers: map[string]string{"X-Test-Complete-Header": "CompleteValue"},
+			Body:    []byte(chatCompletionJSON),
 		},
 		TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}},
 	}
 
-	_, err := director.HandleResponseBodyComplete(ctx, reqCtx)
+	_, err = director.HandleResponseBodyComplete(ctx, reqCtx)
 	if err != nil {
 		t.Fatalf("HandleResponseBodyComplete() returned unexpected error: %v", err)
 	}
-	if diff := cmp.Diff("test-req-id-for-complete", pc1.lastRespOnComplete.RequestId); diff != "" {
-		t.Errorf("Scheduler.OnComplete RequestId mismatch (-want +got):\n%s", diff)
-	}
-	if diff := cmp.Diff(reqCtx.Response.Headers, pc1.lastRespOnComplete.Headers); diff != "" {
-		t.Errorf("Scheduler.OnComplete Headers mismatch (-want +got):\n%s", diff)
-	}
 	if diff := cmp.Diff("namespace1/test-pod-name", pc1.lastTargetPodOnComplete); diff != "" {
 		t.Errorf("Scheduler.OnComplete TargetPodName mismatch (-want +got):\n%s", diff)
 	}
+	if diff := cmp.Diff(wantLLMResponse, pc1.lastRespOnComplete); diff != "" {
+		t.Errorf("Scheduler.OnComplete response body mismatch (-want +got):\n%s", diff)
+	}
 }
 
 const (
@@ -709,7 +728,7 @@ type testResponseStreaming struct {
 
 type testResponseComplete struct {
 	tn plugins.TypedName
 
-	lastRespOnComplete      *Response
+	lastRespOnComplete      *schedulingtypes.LLMResponse
 	lastTargetPodOnComplete string
 }
@@ -753,7 +772,7 @@ func (p *testResponseStreaming) ResponseStreaming(_ context.Context, _ *scheduli
 	p.lastTargetPodOnStreaming = targetPod.NamespacedName.String()
 }
 
-func (p *testResponseComplete) ResponseComplete(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
+func (p *testResponseComplete) ResponseComplete(_ context.Context, _ *schedulingtypes.LLMRequest, response *schedulingtypes.LLMResponse, targetPod *backend.Pod) {
 	p.lastRespOnComplete = response
 	p.lastTargetPodOnComplete = targetPod.NamespacedName.String()
 }
diff --git a/pkg/epp/requestcontrol/plugins.go b/pkg/epp/requestcontrol/plugins.go
index 44334c68f..1c7eefdac 100644
--- a/pkg/epp/requestcontrol/plugins.go
+++ b/pkg/epp/requestcontrol/plugins.go
@@ -55,5 +55,5 @@ type ResponseStreaming interface {
 // ResponseComplete is called by the director after the complete response is sent.
 type ResponseComplete interface {
 	plugins.Plugin
-	ResponseComplete(ctx context.Context, request *types.LLMRequest, response *Response, targetPod *backend.Pod)
+	ResponseComplete(ctx context.Context, request *types.LLMRequest, response *types.LLMResponse, targetPod *backend.Pod)
 }
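The signature change in `plugins.go` is breaking for out-of-tree request-control plugins. A hedged sketch of a migrated plugin — `usageLogger` is a hypothetical name, and it assumes `plugins.Plugin` is satisfied by `TypedName()` alone (any further methods the interface actually requires would need stubs too):

```go
package example

import (
	"context"

	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// usageLogger is a hypothetical ResponseComplete plugin under the new signature.
type usageLogger struct {
	tn plugins.TypedName
}

func (p *usageLogger) TypedName() plugins.TypedName { return p.tn }

// ResponseComplete now receives the parsed *types.LLMResponse instead of the old
// requestcontrol.Response (request ID + headers), so plugins can inspect choices
// and token usage directly.
func (p *usageLogger) ResponseComplete(_ context.Context, _ *types.LLMRequest, response *types.LLMResponse, _ *backend.Pod) {
	if response.ChatCompletion != nil && response.ChatCompletion.Usage != nil {
		_ = response.ChatCompletion.Usage.TotalTokens // e.g. export as a metric
	}
}
```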
diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
index eb45edeab..fa654f73b 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
@@ -17,6 +17,7 @@ limitations under the License.
 package prefix
 
 import (
+	"bytes"
 	"context"
 	"encoding/binary"
 	"encoding/json"
@@ -28,6 +29,7 @@ import (
 	k8stypes "k8s.io/apimachinery/pkg/types"
 	"sigs.k8s.io/controller-runtime/pkg/log"
 
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
 	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
@@ -117,6 +119,12 @@ var _ plugins.StateData = &SchedulingContextState{}
 type SchedulingContextState struct {
 	// PrefixHashes is a list of prefix hashes of the request prompt broken into blocks.
 	PrefixHashes []BlockHash
+	// RestBytes holds the trailing bytes of the prompt that could not fill a complete block.
+	// If not empty, they are prepended to the response body when it is hashed on completion,
+	// so the response is cached starting from this partial block. This matters especially in multi-turn scenarios.
+	RestBytes []byte
+	// BlockSize is the block size used to calculate the hashes of the request/response.
+	BlockSize int
 	// A map of server to its longest prefix cache match length.
 	PrefixCacheServers map[ServerID]int
 }
@@ -192,10 +200,13 @@ func (p *Plugin) WithName(name string) *Plugin {
 
 // Score returns the scoring result for the given list of pods based on context.
 func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
+	blockSize := getBlockSize(pods, p.config.DefaultBlockSize)
 	// pre score step, hashing prompt and find longest prefix match.
-	hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config.DefaultBlockSize), p.config.MaxPrefixBlocksToMatch)
+	hashes, restBytes := hashPrompt(ctx, request, blockSize, p.config.MaxPrefixBlocksToMatch)
 	state := &SchedulingContextState{
 		PrefixHashes:       hashes,
+		RestBytes:          restBytes,
+		BlockSize:          blockSize,
 		PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
 	}
@@ -226,7 +237,6 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche
 	targetPod := primaryProfileResult.TargetPods[0].GetPod() // get the first pod of the primary profile
 
 	state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
-	p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it
 	if err != nil {
 		log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId)
 		return
@@ -244,9 +254,7 @@
 	total := len(state.PrefixHashes)
 	matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)]
-
-	blockSize := getBlockSize(primaryProfileResult.TargetPods, p.config.DefaultBlockSize)
-	metrics.RecordPrefixCacheMatch(matchLen*blockSize, total*blockSize)
+	metrics.RecordPrefixCacheMatch(matchLen*state.BlockSize, total*state.BlockSize)
 }
 
 // matchLongestPrefix returns a map of servers and length of prefix that each server caches.
@@ -301,47 +309,59 @@ func (m *Plugin) CleanUpInactivePods(ctx context.Context, handle plugins.Handle)
 // hashPrompt divides the prompt into blocks and calculate the prefix cache for each block.
 // hash[0] is calculated including the model name and cache_salt(if provided), since different models generally don't share prefix cache.
 // For block i, hash(i) = hash(block i content, hash(i-1)).
-func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) []BlockHash {
+// It also returns the trailing bytes that did not fill a complete block.
+func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) ([]BlockHash, []byte) {
 	loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
 	if request == nil || request.Body == nil {
 		loggerDebug.Info("Request or request data is nil, skipping hashing")
-		return nil
+		return nil, nil
 	}
 	userInput, err := getUserInputBytes(request)
 	if err != nil {
 		loggerDebug.Error(err, "Failed to get user input bytes")
-		return nil
+		return nil, nil
 	}
+	prevBlockHash := defaultPrevBlock(request)
+	return hashInputWithPrevBlockHash(ctx, prevBlockHash, 0, userInput, cacheBlockSize, maxPrefixBlocks)
+}
 
-	if len(userInput) < cacheBlockSize {
-		loggerDebug.Info("Request body too small for prefix cache", "size", len(userInput), "block size", cacheBlockSize)
-		return nil
-	}
-	if len(userInput) > cacheBlockSize*maxPrefixBlocks {
-		loggerDebug.Info("Truncating input", "size", len(userInput), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
-		userInput = userInput[:maxPrefixBlocks*cacheBlockSize]
-	}
-	// Split the body into blocks of size cacheBlockSize.
-	// If the last block is smaller than cacheBlockSize, it will be ignored.
-	res := make([]BlockHash, 0, len(userInput)/cacheBlockSize)
-	// Add the model to the first block hash so that different models have different hashes even with the same body.
+func defaultPrevBlock(request *types.LLMRequest) BlockHash {
 	h := xxhash.New()
+	// Add the model to the first block hash so that different models have different hashes even with the same body.
 	_, _ = h.Write([]byte(request.TargetModel))
 	if cacheSalt := request.Body.CacheSalt(); cacheSalt != "" {
 		_, _ = h.Write([]byte(cacheSalt))
 	}
-	prevBlockHash := BlockHash(h.Sum64())
-	for i := 0; i+cacheBlockSize <= len(userInput); i += cacheBlockSize {
+	return BlockHash(h.Sum64())
+}
+
+// hashInputWithPrevBlockHash continues a block-hash chain over input, where prevBlockHash is
+// the hash of the last already-processed block and prevBlockLength is the byte length of all
+// already-processed blocks (used to enforce the maxPrefixBlocks cap across the whole sequence).
+func hashInputWithPrevBlockHash(ctx context.Context, prevBlockHash BlockHash, prevBlockLength int, input []byte, cacheBlockSize int, maxPrefixBlocks int) ([]BlockHash, []byte) {
+	loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
+	if len(input)+prevBlockLength < cacheBlockSize {
+		loggerDebug.Info("Input too small for prefix cache", "size", len(input), "block size", cacheBlockSize)
+		return nil, input
+	}
+	if len(input)+prevBlockLength > cacheBlockSize*maxPrefixBlocks {
+		loggerDebug.Info("Truncating input", "size", len(input), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
+		input = input[:(maxPrefixBlocks*cacheBlockSize - prevBlockLength)]
+	}
+	// Split the input into blocks of size cacheBlockSize.
+	// The trailing bytes that do not fill a complete block are returned to the caller.
+	res := make([]BlockHash, 0, len(input)/cacheBlockSize)
+	lastOffset := 0
+	h := xxhash.New()
+	for i := 0; i+cacheBlockSize <= len(input); i += cacheBlockSize {
 		h.Reset()
-		_, _ = h.Write(userInput[i : i+cacheBlockSize])
+		_, _ = h.Write(input[i : i+cacheBlockSize])
 		_, _ = h.Write(toBytes(prevBlockHash))
 		res = append(res, BlockHash(h.Sum64()))
 		prevBlockHash = res[len(res)-1]
+		lastOffset = i + cacheBlockSize
 	}
-	return res
+	return res, input[lastOffset:]
 }
 
 func toBytes(i BlockHash) []byte {
@@ -356,7 +376,39 @@ func getUserInputBytes(request *types.LLMRequest) ([]byte, error) {
 	}
 
 	// must be chat-completions request at this point, return bytes of entire messages
-	return json.Marshal(request.Body.ChatCompletions.Messages)
+	return types.MarshalMessagesToJSON(request.Body.ChatCompletions.Messages...)
+}
+
+// ResponseComplete hashes the completed response (prepended with any leftover prompt bytes)
+// and records the resulting blocks against the serving pod, so that a follow-up turn of the
+// same conversation scores a prefix match on it.
+func (p *Plugin) ResponseComplete(ctx context.Context, request *types.LLMRequest, response *types.LLMResponse, targetPod *backend.Pod) {
+	state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
+	if err != nil {
+		log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId)
+		return
+	}
+	p.pluginState.Delete(request.RequestId) // delete the state explicitly once we are done using it.
+
+	responseForKVCache, err := response.FirstChoiceContent()
+	if err != nil {
+		log.FromContext(ctx).Error(err, "failed to get first choice content", "requestID", request.RequestId)
+		return
+	}
+	var input bytes.Buffer
+	input.Write(state.RestBytes)
+	input.Write(responseForKVCache)
+
+	server := ServerID(targetPod.NamespacedName)
+	prevBlockHash := defaultPrevBlock(request)
+	prevBlockLength := 0
+	if len(state.PrefixHashes) > 0 {
+		prevBlockHash = state.PrefixHashes[len(state.PrefixHashes)-1]
+		prevBlockLength = len(state.PrefixHashes) * state.BlockSize // byte length of the already-hashed prompt blocks
+	}
+	hashBlocks, _ := hashInputWithPrevBlockHash(ctx, prevBlockHash, prevBlockLength, input.Bytes(), state.BlockSize, p.config.MaxPrefixBlocksToMatch)
+	p.wg.Add(1)
+	go func() {
+		p.indexer.Add(hashBlocks, server)
+		p.wg.Done()
+	}()
 }
 
 func getBlockSize(pods []types.Pod, defaultBlockSize int) int {
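To make the chaining concrete: `Score` hashes the prompt into fixed-size blocks where each block hash folds in the previous one, and `ResponseComplete` later continues the same chain over `RestBytes + response`. A follow-up prompt equal to `prompt + response` therefore reproduces the identical hash sequence and scores a full prefix match. A self-contained sketch of that invariant, with stdlib `hash/fnv` standing in for the plugin's xxhash and a zero seed standing in for `defaultPrevBlock`:

```go
package main

import (
	"encoding/binary"
	"fmt"
	"hash/fnv"
)

// chainHash mirrors the plugin's scheme: hash(i) = hash(block_i, hash(i-1)).
// It returns the block hashes and the trailing bytes that did not fill a block.
func chainHash(prev uint64, input []byte, blockSize int) ([]uint64, []byte) {
	var res []uint64
	last := 0
	for i := 0; i+blockSize <= len(input); i += blockSize {
		h := fnv.New64a()
		h.Write(input[i : i+blockSize])
		var b [8]byte
		binary.LittleEndian.PutUint64(b[:], prev)
		h.Write(b[:])
		prev = h.Sum64()
		res = append(res, prev)
		last = i + blockSize
	}
	return res, input[last:]
}

func main() {
	const blockSize = 4
	prompt, response := []byte("aaaaaa"), []byte("bb")

	// Turn 1: hash the prompt ("aaaa" fills one block, "aa" is the remainder) ...
	promptHashes, rest := chainHash(0, prompt, blockSize)
	// ... then, on completion, continue the chain over remainder+response ("aabb").
	respHashes, _ := chainHash(promptHashes[len(promptHashes)-1], append(rest, response...), blockSize)

	// Turn 2: a follow-up prompt that replays the whole conversation yields exactly
	// the same chain, so both blocks hit the recorded cache entries.
	turn2, _ := chainHash(0, append(prompt, response...), blockSize)
	fmt.Println(turn2[0] == promptHashes[0], turn2[1] == respHashes[0]) // true true
}
```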
diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
index 176393783..0b83af5f3 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
@@ -199,6 +199,105 @@ func TestPrefixPluginCompletion(t *testing.T) {
 	plugin.wg.Wait()
 }
 
+func TestPrefixPluginCompletionWithResponse(t *testing.T) {
+	const defaultBlockSize = 4
+	config := Config{
+		DefaultBlockSize:       defaultBlockSize,
+		MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
+		LRUCapacityPerServer:   DefaultLRUCapacityPerServer,
+	}
+	plugin := New(context.Background(), config)
+
+	pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}
+	pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}
+	pods := []types.Pod{pod1, pod2}
+
+	// -- First Request --
+	// This initial request will populate the cache.
+	req1 := &types.LLMRequest{
+		RequestId:   uuid.NewString(),
+		TargetModel: "test-model1",
+		Body: &types.LLMRequestBody{
+			Completions: &types.CompletionsRequest{
+				Prompt: "aaaaaa",
+			},
+		},
+	}
+	scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods)
+	state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String()))
+	assert.NoError(t, err)
+	t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
+	// Input size is 6, hash block size is 4, so the last 2 characters are left over.
+	// Total hashes = 1 (for the "aaaa" block); the model name is folded into that block's hash
+	// rather than counted as a separate hash.
+	assert.Equal(t, 1, len(state.PrefixHashes), "number of hashes is incorrect")
+	assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers yet")
+	// The last 2 characters are recorded in the state's RestBytes.
+	assert.Equal(t, 2, len(state.RestBytes), "number of restBytes is incorrect")
+	assert.Equal(t, defaultBlockSize, state.BlockSize, "blockSize is incorrect")
+	assert.Equal(t, float64(0), scores[pod1], "score for pod1 should be 0 on first request")
+	assert.Equal(t, float64(0), scores[pod2], "score for pod2 should be 0 on first request")
+
+	// Simulate that the scheduler picked pod1 for the first request.
+	schedulingResult := &types.SchedulingResult{
+		PrimaryProfileName: "default",
+		ProfileResults: map[string]*types.ProfileRunResult{
+			"default": {TargetPods: []types.Pod{pod1}},
+		},
+	}
+	plugin.PreRequest(context.Background(), req1, schedulingResult, 0)
+	plugin.wg.Wait()
+
+	// -- Simulate Response Completion --
+	// The ResponseComplete hook is called. The plugin should update pod1's KV cache
+	// with the full context of the completed interaction (prompt + response).
+	//   - Initial Prompt:  "aaaaaa"
+	//   - Response Body:   "bb"
+	//   - Cached Sequence: "aaaaaabb" (length 8)
+	// This sequence creates two 4-character blocks to be cached: "aaaa" and "aabb".
+	resp1 := &types.LLMResponse{
+		Completion: &types.CompletionResponse{
+			Choices: []types.CompletionChoice{
+				{
+					Text: "bb",
+				},
+			},
+		},
+	}
+	plugin.ResponseComplete(context.Background(), req1, resp1, pod1.GetPod())
+	plugin.wg.Wait()
+
+	// -- Second Request: Multi-turn Follow-up --
+	// This request simulates a follow-up message in a chat. The prompt contains the
+	// entire conversation history ("aaaaaabb") plus new text ("cc").
+	// The plugin should find that the first two blocks ("aaaa", "aabb") of this new
+	// prompt are already cached on pod1, giving it a perfect match score of 1.0.
+	// Pod2 has no matching cache entries and should score 0.
+	req2 := &types.LLMRequest{
+		RequestId:   uuid.NewString(),
+		TargetModel: "test-model1",
+		Body: &types.LLMRequestBody{
+			Completions: &types.CompletionsRequest{
+				Prompt: "aaaaaabbcc",
+			},
+		},
+	}
+	scores = plugin.Score(context.Background(), types.NewCycleState(), req2, pods)
+	state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, plugins.StateKey(plugin.TypedName().String()))
+	assert.NoError(t, err)
+	t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
+	// Input size is 10, hash block size is 4. The prompt "aaaaaabb" generates 2 hashes.
+	// The last 2 characters ("cc") are left over.
+	assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect")
+	// It should find a server (pod1) that has cached the prefixes.
+	assert.Equal(t, 1, len(state.PrefixCacheServers), "a cached server should have been found")
+	// The last 2 characters ("cc") are recorded in the state's RestBytes.
+	assert.Equal(t, 2, len(state.RestBytes), "number of restBytes is incorrect")
+	assert.Equal(t, defaultBlockSize, state.BlockSize, "blockSize is incorrect")
+	// The score for pod1 should be 1.0 because both prompt blocks ("aaaa" and "aabb") were found in its cache.
+	assert.Equal(t, float64(1), scores[pod1], "score for pod1 should be a perfect match")
+	assert.Equal(t, float64(0), scores[pod2], "score for pod2 should be 0")
+}
+
 func TestPrefixPluginChatCompletions(t *testing.T) {
 	config := Config{
 		DefaultBlockSize:       4,
@@ -278,6 +377,19 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
 	plugin.PreRequest(context.Background(), req1, schedulingResult, 0)
 	plugin.wg.Wait()
 
+	resp1 := &types.LLMResponse{
+		ChatCompletion: &types.ChatCompletionResponse{
+			Choices: []types.ChatChoice{
+				{
+					Message: types.Message{Role: "assistant", Content: "I'm doing well, thank you! How can I help you today?"},
+				},
+			},
+		},
+	}
+	// Simulate response completion so that resp1 is recorded against pod1's KV cache.
+	plugin.ResponseComplete(context.Background(), req1, resp1, pod1.GetPod())
+	plugin.wg.Wait()
+
 	// Second request adds assistant response and new user message (conversation grows)
 	req2 := &types.LLMRequest{
 		RequestId: uuid.NewString(),
@@ -305,13 +417,27 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
 	cachedBlocks := state.PrefixCacheServers[ServerID(pod1.GetPod().NamespacedName)]
 	expectedScore := float64(cachedBlocks) / float64(extendedHashCount)
 	assert.Equal(t, expectedScore, scores[pod1], "pod1 should have prefix cache hit")
+	assert.Greater(t, scores[pod1], float64(0.5), "since the response is also prefix-cached, the hit rate should be well above 0.5")
 	assert.Equal(t, float64(0), scores[pod2], "pod2 should have no cache hit")
 
 	// Simulate pod1 was picked again
 	plugin.PreRequest(context.Background(), req2, schedulingResult, 0)
 	plugin.wg.Wait()
 
-	// Third request continues the conversation even further
+	resp2 := &types.LLMResponse{
+		ChatCompletion: &types.ChatCompletionResponse{
+			Choices: []types.ChatChoice{
+				{
+					Message: types.Message{Role: "assistant", Content: "Prefix caching is a technique where..."},
+				},
+			},
+		},
+	}
+	// Simulate response completion so that resp2 is recorded against pod1's KV cache.
+	plugin.ResponseComplete(context.Background(), req2, resp2, pod1.GetPod())
+	plugin.wg.Wait()
+
+	// Third request replays the whole conversation above, so the cache hit rate should reach 1.0.
 	req3 := &types.LLMRequest{
 		RequestId:   uuid.NewString(),
 		TargetModel: "test-model1",
@@ -323,7 +449,6 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
How can I help you today?"}, {Role: "user", Content: "Can you explain how prefix caching works?"}, {Role: "assistant", Content: "Prefix caching is a technique where..."}, - {Role: "user", Content: "That's very helpful, thank you!"}, }, }, }, @@ -340,7 +465,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) { cachedBlocks = state.PrefixCacheServers[ServerID(pod1.GetPod().NamespacedName)] expectedScore = float64(cachedBlocks) / float64(longHashCount) assert.Equal(t, expectedScore, scores[pod1], "pod1 should have higher prefix cache hit") - assert.Greater(t, scores[pod1], float64(0.5), "cache hit rate should be substantial for growing conversation") + assert.Equal(t, scores[pod1], float64(1), "cache hit rate should be substantial for growing conversation") assert.Equal(t, float64(0), scores[pod2], "pod2 should still have no cache hit") } diff --git a/pkg/epp/scheduling/types/llmresponse.go b/pkg/epp/scheduling/types/llmresponse.go new file mode 100644 index 000000000..f6985847a --- /dev/null +++ b/pkg/epp/scheduling/types/llmresponse.go @@ -0,0 +1,136 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package types + +import ( + "encoding/json" + "errors" + "fmt" +) + +// LLMResponse is a structured representation of a parsed LLM response body. +// An LLMResponse must contain exactly one of ChatCompletion or LegacyCompletion. +type LLMResponse struct { + // ChatCompletion is the representation of the OpenAI /v1/chat/completions response body. + ChatCompletion *ChatCompletionResponse `json:"chat_completion,omitempty"` + // Completion is the representation of the OpenAI /v1/completions response body. + Completion *CompletionResponse `json:"legacy_completion,omitempty"` +} + +// FirstChoiceContent extracts the first choice of the response. +func (res *LLMResponse) FirstChoiceContent() ([]byte, error) { + if res.ChatCompletion != nil && len(res.ChatCompletion.Choices) > 0 { + return MarshalMessagesToJSON(res.ChatCompletion.Choices[0].Message) + } + if res.Completion != nil && len(res.Completion.Choices) > 0 { + return []byte(res.Completion.Choices[0].Text), nil + } + return nil, errors.New("no choices found in the LLM response") +} + +// ChatCompletionResponse represents the full response body for the chat completions API. +type ChatCompletionResponse struct { + Choices []ChatChoice `json:"choices"` + Usage *Usage `json:"usage,omitempty"` +} + +func (r *ChatCompletionResponse) String() string { + if r == nil { + return nilString + } + contentLen := 0 + if len(r.Choices) > 0 { + contentLen = len(r.Choices[0].Message.Content) + } + return fmt.Sprintf("{ContentLength: %d, Usage: %s}", contentLen, r.Usage) +} + +// ChatChoice represents a single choice in the chat completion response. +type ChatChoice struct { + Message Message `json:"message"` + FinishReason string `json:"finish_reason"` +} + +// ChatMessage represents the message object within a choice. 
+
+// CompletionResponse represents the full response body for the legacy completions API.
+type CompletionResponse struct {
+	Choices []CompletionChoice `json:"choices"`
+	Usage   *Usage             `json:"usage,omitempty"`
+}
+
+func (r *CompletionResponse) String() string {
+	if r == nil {
+		return nilString
+	}
+	textLen := 0
+	if len(r.Choices) > 0 {
+		textLen = len(r.Choices[0].Text)
+	}
+	return fmt.Sprintf("{TextLength: %d, Usage: %v}", textLen, r.Usage)
+}
+
+// CompletionChoice represents a single choice in the legacy completion response.
+type CompletionChoice struct {
+	Text         string `json:"text"`
+	FinishReason string `json:"finish_reason"`
+}
+
+// Usage represents the token usage data common to all response formats.
+type Usage struct {
+	PromptTokens     int `json:"prompt_tokens"`
+	CompletionTokens int `json:"completion_tokens"`
+	TotalTokens      int `json:"total_tokens"`
+}
+
+func (u *Usage) String() string {
+	if u == nil {
+		return nilString
+	}
+	return fmt.Sprintf("{Prompt: %d, Completion: %d, Total: %d}", u.PromptTokens, u.CompletionTokens, u.TotalTokens)
+}
+
+// NewLLMResponseFromBytes initializes an LLMResponse by trying to parse the data
+// as a chat completion response first, and then as a legacy completion response.
+func NewLLMResponseFromBytes(body []byte) (*LLMResponse, error) {
+	if len(body) == 0 {
+		return nil, errors.New("input bytes are empty")
+	}
+
+	// Attempt to unmarshal as a ChatCompletionResponse first.
+	var chatResp ChatCompletionResponse
+	if err := json.Unmarshal(body, &chatResp); err == nil {
+		// Check whether the role is set to distinguish a chat completion from a legacy completion.
+		if len(chatResp.Choices) > 0 && chatResp.Choices[0].Message.Role != "" {
+			return &LLMResponse{ChatCompletion: &chatResp}, nil
+		}
+	}
+
+	// Fall back to unmarshaling as a legacy CompletionResponse.
+	var legacyResp CompletionResponse
+	if err := json.Unmarshal(body, &legacyResp); err == nil {
+		if len(legacyResp.Choices) > 0 {
+			return &LLMResponse{Completion: &legacyResp}, nil
+		}
+	}
+
+	return nil, errors.New("failed to unmarshal body into any known LLM response format")
+}
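A subtlety behind the `Role != ""` check in `NewLLMResponseFromBytes`: `encoding/json` decodes a legacy completion body into `ChatCompletionResponse` without error (the unknown `text` field is ignored and `message` stays zero-valued), so the presence of choices alone cannot distinguish the two formats. A standalone illustration — the local types here merely mirror the ones above:

```go
package main

import (
	"encoding/json"
	"fmt"
)

type chatMsg struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type chatChoice struct {
	Message chatMsg `json:"message"`
}

type chatResp struct {
	Choices []chatChoice `json:"choices"`
}

func main() {
	// A *legacy* completion body still decodes into the chat shape without error.
	legacy := []byte(`{"choices":[{"text":"Hello there!","finish_reason":"stop"}]}`)
	var c chatResp
	err := json.Unmarshal(legacy, &c)
	// No error, and there IS a choice - but its message role is empty,
	// which is exactly what the role check keys on.
	fmt.Println(err, len(c.Choices), c.Choices[0].Message.Role == "")
	// Output: <nil> 1 true
}
```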
diff --git a/pkg/epp/scheduling/types/llmresponse_test.go b/pkg/epp/scheduling/types/llmresponse_test.go
new file mode 100644
index 000000000..759c9caac
--- /dev/null
+++ b/pkg/epp/scheduling/types/llmresponse_test.go
@@ -0,0 +1,403 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package types
+
+import (
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+)
+
+func TestNewLLMResponseFromBytes(t *testing.T) {
+	chatCompletionJSON := `{
+		"choices": [
+			{
+				"message": {
+					"role": "assistant",
+					"content": "Hello!"
+				},
+				"finish_reason": "stop"
+			}
+		],
+		"usage": {
+			"prompt_tokens": 1,
+			"completion_tokens": 2,
+			"total_tokens": 3
+		}
+	}`
+
+	legacyCompletionJSON := `{
+		"choices": [
+			{
+				"text": "Hello there!",
+				"finish_reason": "stop"
+			}
+		],
+		"usage": {
+			"prompt_tokens": 4,
+			"completion_tokens": 5,
+			"total_tokens": 9
+		}
+	}`
+
+	chatCompletionEmptyChoicesJSON := `{
+		"choices": [],
+		"usage": {
+			"prompt_tokens": 1,
+			"completion_tokens": 2,
+			"total_tokens": 3
+		}
+	}`
+
+	legacyCompletionEmptyChoicesJSON := `{
+		"choices": [],
+		"usage": {
+			"prompt_tokens": 4,
+			"completion_tokens": 5,
+			"total_tokens": 9
+		}
+	}`
+
+	chatCompletionEmptyUsageJSON := `{
+		"choices": [
+			{
+				"message": {
+					"role": "assistant",
+					"content": "Hello!"
+				},
+				"finish_reason": "stop"
+			}
+		]
+	}`
+
+	legacyCompletionEmptyUsageJSON := `{
+		"choices": [
+			{
+				"text": "Hello there!",
+				"finish_reason": "stop"
+			}
+		]
+	}`
+
+	invalidJSON := `{"invalid": json}`
+	unstructuredJSON := `{"foo": "bar"}`
+
+	testCases := []struct {
+		name      string
+		input     []byte
+		want      *LLMResponse
+		wantError bool
+	}{
+		{
+			name:  "valid chat completion response",
+			input: []byte(chatCompletionJSON),
+			want: &LLMResponse{
+				ChatCompletion: &ChatCompletionResponse{
+					Choices: []ChatChoice{
+						{
+							Message: Message{
+								Role:    "assistant",
+								Content: "Hello!",
+							},
+							FinishReason: "stop",
+						},
+					},
+					Usage: &Usage{
+						PromptTokens:     1,
+						CompletionTokens: 2,
+						TotalTokens:      3,
+					},
+				},
+			},
+			wantError: false,
+		},
+		{
+			name:  "valid legacy completion response",
+			input: []byte(legacyCompletionJSON),
+			want: &LLMResponse{
+				Completion: &CompletionResponse{
+					Choices: []CompletionChoice{
+						{
+							Text:         "Hello there!",
+							FinishReason: "stop",
+						},
+					},
+					Usage: &Usage{
+						PromptTokens:     4,
+						CompletionTokens: 5,
+						TotalTokens:      9,
+					},
+				},
+			},
+			wantError: false,
+		},
+		{
+			name:      "invalid json",
+			input:     []byte(invalidJSON),
+			want:      nil,
+			wantError: true,
+		},
+		{
+			name:      "empty input",
+			input:     []byte{},
+			want:      nil,
+			wantError: true,
+		},
+		{
+			name:      "unstructured json",
+			input:     []byte(unstructuredJSON),
+			want:      nil,
+			wantError: true,
+		},
+		{
+			name:      "chat completion with empty choices",
+			input:     []byte(chatCompletionEmptyChoicesJSON),
+			want:      nil,
+			wantError: true,
+		},
+		{
+			name:      "legacy completion with empty choices",
+			input:     []byte(legacyCompletionEmptyChoicesJSON),
+			want:      nil,
+			wantError: true,
+		},
+		{
+			name:  "chat completion with empty usage",
+			input: []byte(chatCompletionEmptyUsageJSON),
+			want: &LLMResponse{
+				ChatCompletion: &ChatCompletionResponse{
+					Choices: []ChatChoice{
+						{
+							Message: Message{
+								Role:    "assistant",
+								Content: "Hello!",
+							},
+							FinishReason: "stop",
+						},
+					},
+				},
+			},
+			wantError: false,
+		},
+		{
+			name:  "legacy completion with empty usage",
+			input: []byte(legacyCompletionEmptyUsageJSON),
+			want: &LLMResponse{
+				Completion: &CompletionResponse{
+					Choices: []CompletionChoice{
+						{
+							Text:         "Hello there!",
+							FinishReason: "stop",
+						},
+					},
+				},
+			},
+			wantError: false,
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			got, err := NewLLMResponseFromBytes(tc.input)
+
+			if (err != nil) != tc.wantError {
+				t.Errorf("NewLLMResponseFromBytes() error = %v, wantError %v", err, tc.wantError)
+				return
+			}
+
+			if diff := cmp.Diff(tc.want, got); diff != "" {
+				t.Errorf("NewLLMResponseFromBytes() mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
+
+func TestFirstChoiceContent(t *testing.T) {
+	testCases := []struct {
+		name      string
+		res       *LLMResponse
+		want      []byte
+		wantError bool
+	}{
+		{
"chatCompletion with choice", + res: &LLMResponse{ + ChatCompletion: &ChatCompletionResponse{ + Choices: []ChatChoice{ + {Message: Message{Role: "assistant", Content: "Hello from Chat"}}, + }, + }, + }, + want: []byte(`{"Role":"assistant","Content":"Hello from Chat"},`), + }, + { + name: "legacyCompletion with choice", + res: &LLMResponse{ + Completion: &CompletionResponse{ + Choices: []CompletionChoice{ + {Text: "Hello from Legacy"}, + }, + }, + }, + want: []byte(`Hello from Legacy`), + }, + { + name: "chatCompletion with no choices", + res: &LLMResponse{ + ChatCompletion: &ChatCompletionResponse{ + Choices: []ChatChoice{}, + }, + }, + wantError: true, + }, + { + name: "legacyCompletion with no choices", + res: &LLMResponse{ + Completion: &CompletionResponse{ + Choices: []CompletionChoice{}, + }, + }, + wantError: true, + }, + { + name: "LLMResponse with all fields nil", + res: &LLMResponse{ + ChatCompletion: nil, + Completion: nil, + }, + wantError: true, + }, + { + name: "Empty LLMResponse struct", + res: &LLMResponse{}, + wantError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := tc.res.FirstChoiceContent() + if tc.wantError != (err != nil) { + t.Errorf("FirstChoiceContent() wantError is %v, but got error: %v", tc.wantError, err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Errorf("FirstChoiceContent() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestUsage_String(t *testing.T) { + var nilUsage *Usage + tests := []struct { + name string + u *Usage + want string + }{ + { + name: "nil usage", + u: nilUsage, + want: nilString, + }, + { + name: "non-nil usage", + u: &Usage{PromptTokens: 1, CompletionTokens: 2, TotalTokens: 3}, + want: "{Prompt: 1, Completion: 2, Total: 3}", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.u.String(); got != tt.want { + t.Errorf("Usage.String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestChatCompletionResponse_String(t *testing.T) { + var nilResp *ChatCompletionResponse + tests := []struct { + name string + r *ChatCompletionResponse + want string + }{ + { + name: "nil response", + r: nilResp, + want: nilString, + }, + { + name: "response with no choices", + r: &ChatCompletionResponse{Choices: []ChatChoice{}, Usage: &Usage{}}, + want: "{ContentLength: 0, Usage: {Prompt: 0, Completion: 0, Total: 0}}", + }, + { + name: "response with choices", + r: &ChatCompletionResponse{ + Choices: []ChatChoice{ + {Message: Message{Content: "hello"}}, + }, + Usage: &Usage{PromptTokens: 1, CompletionTokens: 2, TotalTokens: 3}, + }, + want: "{ContentLength: 5, Usage: {Prompt: 1, Completion: 2, Total: 3}}", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.r.String(); got != tt.want { + t.Errorf("ChatCompletionResponse.String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestLegacyCompletionResponse_String(t *testing.T) { + var nilResp *CompletionResponse + tests := []struct { + name string + r *CompletionResponse + want string + }{ + { + name: "nil response", + r: nilResp, + want: nilString, + }, + { + name: "response with no choices", + r: &CompletionResponse{Choices: []CompletionChoice{}, Usage: &Usage{}}, + want: "{TextLength: 0, Usage: {Prompt: 0, Completion: 0, Total: 0}}", + }, + { + name: "response with choices", + r: &CompletionResponse{ + Choices: []CompletionChoice{ + {Text: "hello world"}, + }, + Usage: &Usage{PromptTokens: 1, CompletionTokens: 2, TotalTokens: 3}, + }, + want: 
"{TextLength: 11, Usage: {Prompt: 1, Completion: 2, Total: 3}}", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.r.String(); got != tt.want { + t.Errorf("LegacyCompletionResponse.String() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go index c93f0c5ac..18c2a6831 100644 --- a/pkg/epp/scheduling/types/types.go +++ b/pkg/epp/scheduling/types/types.go @@ -17,13 +17,18 @@ limitations under the License. package types import ( + "bytes" + "encoding/json" "fmt" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" ) -const nilString = "" +const ( + nilString = "" + messageSplit = "," +) // LLMRequest is a structured representation of the fields we parse out of the LLMRequest body. type LLMRequest struct { @@ -125,6 +130,25 @@ type Message struct { Content string // TODO: support multi-modal content } +// MarshalMessagesToJSON converts a slice of Message structs into a JSON byte slice. +// This is used to create a consistent byte representation for prefix caching calculations, +// allowing us to identify common prefixes between LLM requests and responses. +func MarshalMessagesToJSON(messages ...Message) ([]byte, error) { + if len(messages) == 0 { + return []byte{}, nil + } + var buf bytes.Buffer + for _, msg := range messages { + jsonBytes, err := json.Marshal(msg) + if err != nil { + return []byte{}, err + } + buf.Write(jsonBytes) + buf.WriteString(messageSplit) + } + return buf.Bytes(), nil +} + type Pod interface { GetPod() *backend.Pod GetMetrics() *backendmetrics.MetricsState diff --git a/pkg/epp/scheduling/types/types_test.go b/pkg/epp/scheduling/types/types_test.go new file mode 100644 index 000000000..711b6dffb --- /dev/null +++ b/pkg/epp/scheduling/types/types_test.go @@ -0,0 +1,69 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package types + +import ( + "bytes" + "testing" +) + +func TestMarshalMessagesToJSON(t *testing.T) { + testCases := []struct { + name string + messages []Message + want []byte + wantErr bool + }{ + { + name: "empty messages", + messages: []Message{}, + want: []byte{}, + wantErr: false, + }, + { + name: "single message", + messages: []Message{ + {Role: "user", Content: "Hello"}, + }, + want: []byte(`{"Role":"user","Content":"Hello"},`), + wantErr: false, + }, + { + name: "multiple messages", + messages: []Message{ + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "Hi there!"}, + }, + want: []byte(`{"Role":"user","Content":"Hello"},{"Role":"assistant","Content":"Hi there!"},`), + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := MarshalMessagesToJSON(tc.messages...) 
diff --git a/pkg/epp/scheduling/types/types_test.go b/pkg/epp/scheduling/types/types_test.go
new file mode 100644
index 000000000..711b6dffb
--- /dev/null
+++ b/pkg/epp/scheduling/types/types_test.go
@@ -0,0 +1,69 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package types
+
+import (
+	"bytes"
+	"testing"
+)
+
+func TestMarshalMessagesToJSON(t *testing.T) {
+	testCases := []struct {
+		name     string
+		messages []Message
+		want     []byte
+		wantErr  bool
+	}{
+		{
+			name:     "empty messages",
+			messages: []Message{},
+			want:     []byte{},
+			wantErr:  false,
+		},
+		{
+			name: "single message",
+			messages: []Message{
+				{Role: "user", Content: "Hello"},
+			},
+			want:    []byte(`{"Role":"user","Content":"Hello"},`),
+			wantErr: false,
+		},
+		{
+			name: "multiple messages",
+			messages: []Message{
+				{Role: "user", Content: "Hello"},
+				{Role: "assistant", Content: "Hi there!"},
+			},
+			want:    []byte(`{"Role":"user","Content":"Hello"},{"Role":"assistant","Content":"Hi there!"},`),
+			wantErr: false,
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			got, err := MarshalMessagesToJSON(tc.messages...)
+			if (err != nil) != tc.wantErr {
+				t.Errorf("MarshalMessagesToJSON() error = %v, wantErr %v", err, tc.wantErr)
+				return
+			}
+
+			if !bytes.Equal(got, tc.want) {
+				t.Errorf("MarshalMessagesToJSON() got = %s, want %s", string(got), string(tc.want))
+			}
+		})
+	}
+}
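Finally, nothing in this diff wires the prefix plugin into the director's `ResponseComplete` hook by itself; the director test shows the mechanism with a stub plugin. A hedged sketch of production wiring, assuming a datastore `ds`, scheduler `sched`, and context `ctx` already exist, and using only constructor names visible in the tests (the block-size literal is illustrative):

```go
// The prefix plugin instance is shared: it already participates in scheduling
// (Score/PreRequest) and, with this PR, also implements requestcontrol.ResponseComplete,
// so the same instance is handed to both subsystems.
prefixPlugin := prefix.New(ctx, prefix.Config{
	DefaultBlockSize:       64, // illustrative; use the deployment's configured block size
	MaxPrefixBlocksToMatch: prefix.DefaultMaxPrefixBlocks,
	LRUCapacityPerServer:   prefix.DefaultLRUCapacityPerServer,
})
director := requestcontrol.NewDirectorWithConfig(
	ds, sched, nil,
	requestcontrol.NewConfig().WithResponseCompletePlugins(prefixPlugin),
)
_ = director
```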