diff --git a/pkg/epp/handlers/response.go b/pkg/epp/handlers/response.go index 1cbacbae3..cf7d1ff60 100644 --- a/pkg/epp/handlers/response.go +++ b/pkg/epp/handlers/response.go @@ -17,16 +17,15 @@ limitations under the License. package handlers import ( + "bytes" "context" - "encoding/json" - "fmt" - "strings" configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -36,23 +35,19 @@ const ( ) // HandleResponseBody always returns the requestContext even in the error case, as the request context is used in error handling. -func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *RequestContext, response map[string]any) (*RequestContext, error) { +func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *RequestContext, body []byte) (*RequestContext, error) { logger := log.FromContext(ctx) - responseBytes, err := json.Marshal(response) + llmResponse, err := types.NewLLMResponseFromBytes(body) if err != nil { - return reqCtx, fmt.Errorf("error marshalling responseBody - %w", err) - } - if response["usage"] != nil { - usg := response["usage"].(map[string]any) - usage := Usage{ - PromptTokens: int(usg["prompt_tokens"].(float64)), - CompletionTokens: int(usg["completion_tokens"].(float64)), - TotalTokens: int(usg["total_tokens"].(float64)), + logger.Error(err, "failed to create LLMResponse from bytes") + } else { + reqCtx.SchedulingResponse = llmResponse + if usage := reqCtx.SchedulingResponse.Usage(); usage != nil { + reqCtx.Usage = usage + logger.V(logutil.VERBOSE).Info("Response generated", "usage", usage) } - reqCtx.Usage = usage - logger.V(logutil.VERBOSE).Info("Response generated", "usage", reqCtx.Usage) } - reqCtx.ResponseSize = len(responseBytes) + reqCtx.ResponseSize = len(body) // ResponseComplete is to indicate the response is complete. In non-streaming // case, it will be set to be true once the response is processed; in // streaming case, it will be set to be true once the last chunk is processed. @@ -60,25 +55,36 @@ func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *Reques // will add the processing for streaming case. reqCtx.ResponseComplete = true - reqCtx.respBodyResp = generateResponseBodyResponses(responseBytes, true) + reqCtx.respBodyResp = generateResponseBodyResponses(body, true) return s.director.HandleResponseBodyComplete(ctx, reqCtx) } // The function is to handle streaming response if the modelServer is streaming. 
-func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) { +func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, streamBody []byte) { logger := log.FromContext(ctx) _, err := s.director.HandleResponseBodyStreaming(ctx, reqCtx) if err != nil { logger.Error(err, "error in HandleResponseBodyStreaming") } - if strings.Contains(responseText, streamingEndMsg) { +} + +func (s *StreamingServer) HandleResponseBodyModelStreamingComplete(ctx context.Context, reqCtx *RequestContext, streamBody []byte) { + logger := log.FromContext(ctx) + if bytes.Contains(streamBody, []byte(streamingEndMsg)) { reqCtx.ResponseComplete = true - resp := parseRespForUsage(ctx, responseText) - reqCtx.Usage = resp.Usage - metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.PromptTokens) - metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.CompletionTokens) - _, err := s.director.HandleResponseBodyComplete(ctx, reqCtx) + resp, err := types.NewLLMResponseFromStream(streamBody) + if err != nil { + logger.Error(err, "error in converting stream response to LLMResponse.") + } else { + reqCtx.SchedulingResponse = resp + if usage := resp.Usage(); usage != nil { + reqCtx.Usage = usage + metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, usage.PromptTokens) + metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, usage.CompletionTokens) + } + } + _, err = s.director.HandleResponseBodyComplete(ctx, reqCtx) if err != nil { logger.Error(err, "error in HandleResponseBodyComplete") } @@ -153,41 +159,6 @@ func (s *StreamingServer) generateResponseHeaders(reqCtx *RequestContext) []*con return headers } -// Example message if "stream_options": {"include_usage": "true"} is included in the request: -// data: {"id":"...","object":"text_completion","created":1739400043,"model":"food-review-0","choices":[], -// "usage":{"prompt_tokens":7,"total_tokens":17,"completion_tokens":10}} -// -// data: [DONE] -// -// Noticed that vLLM returns two entries in one response. -// We need to strip the `data:` prefix and next Data: [DONE] from the message to fetch response data. -// -// If include_usage is not included in the request, `data: [DONE]` is returned separately, which -// indicates end of streaming. 
-func parseRespForUsage(ctx context.Context, responseText string) ResponseBody { - response := ResponseBody{} - logger := log.FromContext(ctx) - - lines := strings.Split(responseText, "\n") - for _, line := range lines { - if !strings.HasPrefix(line, streamingRespPrefix) { - continue - } - content := strings.TrimPrefix(line, streamingRespPrefix) - if content == "[DONE]" { - continue - } - - byteSlice := []byte(content) - if err := json.Unmarshal(byteSlice, &response); err != nil { - logger.Error(err, "unmarshaling response body") - continue - } - } - - return response -} - type ResponseBody struct { Usage Usage `json:"usage"` } diff --git a/pkg/epp/handlers/response_test.go b/pkg/epp/handlers/response_test.go index 63b2de0da..46d7644b6 100644 --- a/pkg/epp/handlers/response_test.go +++ b/pkg/epp/handlers/response_test.go @@ -18,12 +18,12 @@ package handlers import ( "context" - "encoding/json" "testing" "github.com/google/go-cmp/cmp" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -52,12 +52,33 @@ const ( } ` - streamingBodyWithoutUsage = `data: {"id":"cmpl-41764c93-f9d2-4f31-be08-3ba04fa25394","object":"text_completion","created":1740002445,"model":"food-review-0","choices":[],"usage":null} - ` + streamingBodyWithoutUsage = ` + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"}}]} - streamingBodyWithUsage = `data: {"id":"cmpl-41764c93-f9d2-4f31-be08-3ba04fa25394","object":"text_completion","created":1740002445,"model":"food-review-0","choices":[],"usage":{"prompt_tokens":7,"total_tokens":17,"completion_tokens":10}} -data: [DONE] - ` + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Hello"}}]} + + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":" world"}}]} + + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]} + + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[],"usage":null} + + data: [DONE] + ` + + streamingBodyWithUsage = ` + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"}}]} + + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Hello"}}]} + + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":" world"}}]} + + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]} + + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[],"usage":{"prompt_tokens":5,"completion_tokens":7,"total_tokens":12}} + + data: [DONE] + ` ) type mockDirector struct{} @@ -88,13 +109,13 @@ func TestHandleResponseBody(t *testing.T) { name string body []byte reqCtx *RequestContext - want Usage + want *types.Usage wantErr bool }{ { name: "success", body: []byte(body), - want: Usage{ + want: &types.Usage{ PromptTokens: 11, TotalTokens: 111, CompletionTokens: 100, @@ -110,12 +131,7 @@ func TestHandleResponseBody(t *testing.T) { if reqCtx == nil { reqCtx = &RequestContext{} } - var responseMap map[string]any - marshalErr := json.Unmarshal(test.body, &responseMap) - if marshalErr != nil { - t.Error(marshalErr, "Error unmarshaling request body") - } - _, err := 
server.HandleResponseBody(ctx, reqCtx, responseMap) + _, err := server.HandleResponseBody(ctx, reqCtx, test.body) if err != nil { if !test.wantErr { t.Fatalf("HandleResponseBody returned unexpected error: %v, want %v", err, test.wantErr) @@ -136,7 +152,7 @@ func TestHandleStreamedResponseBody(t *testing.T) { name string body string reqCtx *RequestContext - want Usage + want *types.Usage wantErr bool }{ { @@ -155,10 +171,10 @@ func TestHandleStreamedResponseBody(t *testing.T) { modelServerStreaming: true, }, wantErr: false, - want: Usage{ - PromptTokens: 7, - TotalTokens: 17, - CompletionTokens: 10, + want: &types.Usage{ + PromptTokens: 5, + TotalTokens: 12, + CompletionTokens: 7, }, }, } @@ -171,7 +187,8 @@ func TestHandleStreamedResponseBody(t *testing.T) { if reqCtx == nil { reqCtx = &RequestContext{} } - server.HandleResponseBodyModelStreaming(ctx, reqCtx, test.body) + server.HandleResponseBodyModelStreaming(ctx, reqCtx, []byte(test.body)) + server.HandleResponseBodyModelStreamingComplete(ctx, reqCtx, []byte(test.body)) if diff := cmp.Diff(test.want, reqCtx.Usage); diff != "" { t.Errorf("HandleResponseBody returned unexpected response, diff(-want, +got): %v", diff) diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 001fdc344..62b49ec73 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -85,14 +85,15 @@ type RequestContext struct { RequestReceivedTimestamp time.Time ResponseCompleteTimestamp time.Time RequestSize int - Usage Usage + Usage *schedulingtypes.Usage ResponseSize int ResponseComplete bool ResponseStatusCode string RequestRunning bool Request *Request - SchedulingRequest *schedulingtypes.LLMRequest + SchedulingRequest *schedulingtypes.LLMRequest + SchedulingResponse *schedulingtypes.LLMResponse RequestState StreamRequestState modelServerStreaming bool @@ -267,13 +268,13 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) reqCtx.respHeaderResp = s.generateResponseHeaderResponse(reqCtx) case *extProcPb.ProcessingRequest_ResponseBody: + body = append(body, v.ResponseBody.Body...) if reqCtx.modelServerStreaming { // Currently we punt on response parsing if the modelServer is streaming, and we just passthrough. - - responseText := string(v.ResponseBody.Body) - s.HandleResponseBodyModelStreaming(ctx, reqCtx, responseText) + s.HandleResponseBodyModelStreaming(ctx, reqCtx, v.ResponseBody.Body) if v.ResponseBody.EndOfStream { loggerTrace.Info("stream completed") + s.HandleResponseBodyModelStreamingComplete(ctx, reqCtx, body) reqCtx.ResponseCompleteTimestamp = time.Now() metrics.RecordRequestLatencies(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp) @@ -281,38 +282,36 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) } reqCtx.respBodyResp = generateResponseBodyResponses(v.ResponseBody.Body, v.ResponseBody.EndOfStream) - } else { - body = append(body, v.ResponseBody.Body...) - - // Message is buffered, we can read and decode. - if v.ResponseBody.EndOfStream { - loggerTrace.Info("stream completed") - // Don't send a 500 on a response error. Just let the message passthrough and log our error for debugging purposes. - // We assume the body is valid JSON, err messages are not guaranteed to be json, and so capturing and sending a 500 obfuscates the response message. - // Using the standard 'err' var will send an immediate error response back to the caller. 
- var responseErr error - responseErr = json.Unmarshal(body, &responseBody) - if responseErr != nil { - if logger.V(logutil.DEBUG).Enabled() { - logger.V(logutil.DEBUG).Error(responseErr, "Error unmarshalling request body", "body", string(body)) - } else { - logger.V(logutil.DEFAULT).Error(responseErr, "Error unmarshalling request body", "body", string(body)) - } - reqCtx.respBodyResp = generateResponseBodyResponses(body, true) - break + } else if v.ResponseBody.EndOfStream { + loggerTrace.Info("stream completed") + // Don't send a 500 on a response error. Just let the message passthrough and log our error for debugging purposes. + // We assume the body is valid JSON, err messages are not guaranteed to be json, and so capturing and sending a 500 obfuscates the response message. + // Using the standard 'err' var will send an immediate error response back to the caller. + var responseErr error + responseErr = json.Unmarshal(body, &responseBody) + if responseErr != nil { + if logger.V(logutil.DEBUG).Enabled() { + logger.V(logutil.DEBUG).Error(responseErr, "Error unmarshalling request body", "body", string(body)) + } else { + logger.V(logutil.DEFAULT).Error(responseErr, "Error unmarshalling request body", "body", string(body)) } + reqCtx.respBodyResp = generateResponseBodyResponses(body, true) + break + } - reqCtx, responseErr = s.HandleResponseBody(ctx, reqCtx, responseBody) - if responseErr != nil { - if logger.V(logutil.DEBUG).Enabled() { - logger.V(logutil.DEBUG).Error(responseErr, "Failed to process response body", "request", req) - } else { - logger.V(logutil.DEFAULT).Error(responseErr, "Failed to process response body") - } - } else if reqCtx.ResponseComplete { - reqCtx.ResponseCompleteTimestamp = time.Now() - metrics.RecordRequestLatencies(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp) - metrics.RecordResponseSizes(reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.ResponseSize) + reqCtx, responseErr = s.HandleResponseBody(ctx, reqCtx, body) + if responseErr != nil { + if logger.V(logutil.DEBUG).Enabled() { + logger.V(logutil.DEBUG).Error(responseErr, "Failed to process response body", "request", req) + } else { + logger.V(logutil.DEFAULT).Error(responseErr, "Failed to process response body") + } + } else if reqCtx.ResponseComplete { + reqCtx.ResponseCompleteTimestamp = time.Now() + metrics.RecordRequestLatencies(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp) + metrics.RecordResponseSizes(reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.ResponseSize) + if reqCtx.Usage != nil { + // Response complete does not guarantee the Usage is populated. metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.Usage.PromptTokens) metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.Usage.CompletionTokens) } diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index d7de39d4a..2142cb367 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -20,6 +20,7 @@ package requestcontrol import ( "context" + "errors" "fmt" "math/rand" "net" @@ -289,14 +290,14 @@ func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *hand // HandleResponseBodyComplete is called when the response body is fully received. 
func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) { - logger := log.FromContext(ctx).WithValues("stage", "bodyChunk") + requestID := reqCtx.Request.Headers[requtil.RequestIdHeaderKey] + logger := log.FromContext(ctx).WithValues("stage", "bodyChunk", requtil.RequestIdHeaderKey, requestID) logger.V(logutil.DEBUG).Info("Entering HandleResponseBodyComplete") - response := &Response{ - RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey], - Headers: reqCtx.Response.Headers, + if reqCtx.SchedulingResponse == nil { + err := errors.New("nil scheduling response from reqCtx") + return reqCtx, err } - - d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod) + d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, reqCtx.SchedulingResponse, reqCtx.TargetPod) logger.V(logutil.DEBUG).Info("Exiting HandleResponseBodyComplete") return reqCtx, nil @@ -346,7 +347,7 @@ func (d *Director) runResponseStreamingPlugins(ctx context.Context, request *sch } } -func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) { +func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *schedulingtypes.LLMResponse, targetPod *backend.Pod) { loggerDebug := log.FromContext(ctx).V(logutil.DEBUG) for _, plugin := range d.requestControlPlugins.responseCompletePlugins { loggerDebug.Info("Running ResponseComplete plugin", "plugin", plugin.TypedName()) diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go index ffd62da36..be5bfe9c0 100644 --- a/pkg/epp/requestcontrol/director_test.go +++ b/pkg/epp/requestcontrol/director_test.go @@ -661,6 +661,27 @@ func TestDirector_HandleResponseComplete(t *testing.T) { mockSched := &mockScheduler{} director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithResponseCompletePlugins(pc1)) + chatCompletionJSON := `{ + "choices": [ + { + "message": { + "role": "assistant", + "content": "Hello!" 
+ }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 1, + "completion_tokens": 2, + "total_tokens": 3 + } + }` + wantLLMResponse, err := schedulingtypes.NewLLMResponseFromBytes([]byte(chatCompletionJSON)) + if err != nil { + t.Fatalf("NewLLMResponseFromBytes failed with error: %v", err) + } + reqCtx := &handlers.RequestContext{ Request: &handlers.Request{ Headers: map[string]string{ @@ -670,23 +691,21 @@ func TestDirector_HandleResponseComplete(t *testing.T) { Response: &handlers.Response{ Headers: map[string]string{"X-Test-Complete-Header": "CompleteValue"}, }, - TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}}, + SchedulingResponse: wantLLMResponse, + TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}}, } - _, err := director.HandleResponseBodyComplete(ctx, reqCtx) + _, err = director.HandleResponseBodyComplete(ctx, reqCtx) if err != nil { t.Fatalf("HandleResponseBodyComplete() returned unexpected error: %v", err) } - if diff := cmp.Diff("test-req-id-for-complete", pc1.lastRespOnComplete.RequestId); diff != "" { - t.Errorf("Scheduler.OnComplete RequestId mismatch (-want +got):\n%s", diff) - } - if diff := cmp.Diff(reqCtx.Response.Headers, pc1.lastRespOnComplete.Headers); diff != "" { - t.Errorf("Scheduler.OnComplete Headers mismatch (-want +got):\n%s", diff) - } if diff := cmp.Diff("namespace1/test-pod-name", pc1.lastTargetPodOnComplete); diff != "" { t.Errorf("Scheduler.OnComplete TargetPodName mismatch (-want +got):\n%s", diff) } + if diff := cmp.Diff(wantLLMResponse, pc1.lastRespOnComplete); diff != "" { + t.Errorf("Scheduler.OnComplete response body mismatch (-want +got):\n%s", diff) + } } const ( @@ -709,7 +728,7 @@ type testResponseStreaming struct { type testResponseComplete struct { tn plugins.TypedName - lastRespOnComplete *Response + lastRespOnComplete *schedulingtypes.LLMResponse lastTargetPodOnComplete string } @@ -753,7 +772,7 @@ func (p *testResponseStreaming) ResponseStreaming(_ context.Context, _ *scheduli p.lastTargetPodOnStreaming = targetPod.NamespacedName.String() } -func (p *testResponseComplete) ResponseComplete(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) { +func (p *testResponseComplete) ResponseComplete(_ context.Context, _ *schedulingtypes.LLMRequest, response *schedulingtypes.LLMResponse, targetPod *backend.Pod) { p.lastRespOnComplete = response p.lastTargetPodOnComplete = targetPod.NamespacedName.String() } diff --git a/pkg/epp/requestcontrol/plugins.go b/pkg/epp/requestcontrol/plugins.go index 44334c68f..1c7eefdac 100644 --- a/pkg/epp/requestcontrol/plugins.go +++ b/pkg/epp/requestcontrol/plugins.go @@ -55,5 +55,5 @@ type ResponseStreaming interface { // ResponseComplete is called by the director after the complete response is sent. type ResponseComplete interface { plugins.Plugin - ResponseComplete(ctx context.Context, request *types.LLMRequest, response *Response, targetPod *backend.Pod) + ResponseComplete(ctx context.Context, request *types.LLMRequest, response *types.LLMResponse, targetPod *backend.Pod) } diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go index eb45edeab..fa654f73b 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go @@ -17,6 +17,7 @@ limitations under the License. 
package prefix

import (
+ "bytes"
 "context"
 "encoding/binary"
 "encoding/json"
@@ -28,6 +29,7 @@ import (
 k8stypes "k8s.io/apimachinery/pkg/types"
 "sigs.k8s.io/controller-runtime/pkg/log"

+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
 backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
@@ -117,6 +119,12 @@ var _ plugins.StateData = &SchedulingContextState{}
type SchedulingContextState struct {
 // PrefixHashes is a list of prefix hashes of the request prompt broken into blocks.
 PrefixHashes []BlockHash
+ // RestBytes holds the trailing bytes that do not fill a complete block and are left over.
+ // If not empty, these bytes are prepended to the response content when the response is hashed,
+ // which matters especially in multi-turn scenarios.
+ RestBytes []byte
+ // BlockSize is the block size used to calculate the hash of the request/response.
+ BlockSize int
 // A map of server to its longest prefix cache match length.
 PrefixCacheServers map[ServerID]int
}
@@ -192,10 +200,13 @@ func (p *Plugin) WithName(name string) *Plugin {
// Score returns the scoring result for the given list of pods based on context.
func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
+ blockSize := getBlockSize(pods, p.config.DefaultBlockSize)
 // pre score step, hashing prompt and find longest prefix match.
- hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config.DefaultBlockSize), p.config.MaxPrefixBlocksToMatch)
+ hashes, restBytes := hashPrompt(ctx, request, blockSize, p.config.MaxPrefixBlocksToMatch)
 state := &SchedulingContextState{
 PrefixHashes: hashes,
+ RestBytes: restBytes,
+ BlockSize: blockSize,
 PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
 }
@@ -226,7 +237,6 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche
 targetPod := primaryProfileResult.TargetPods[0].GetPod() // get the first pod of the primary profile
 state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
- p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it
 if err != nil {
 log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId)
 return
@@ -244,9 +254,7 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche
 total := len(state.PrefixHashes)
 matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)]
-
- blockSize := getBlockSize(primaryProfileResult.TargetPods, p.config.DefaultBlockSize)
- metrics.RecordPrefixCacheMatch(matchLen*blockSize, total*blockSize)
+ metrics.RecordPrefixCacheMatch(matchLen*state.BlockSize, total*state.BlockSize)
}

// matchLongestPrefix returns a map of servers and length of prefix that each server caches.
@@ -301,47 +309,59 @@ func (m *Plugin) CleanUpInactivePods(ctx context.Context, handle plugins.Handle)
// hashPrompt divides the prompt into blocks and calculate the prefix cache for each block.
// hash[0] is calculated including the model name and cache_salt(if provided), since different models generally don't share prefix cache.
// For block i, hash(i) = hash(block i content, hash(i-1)).
-func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) []BlockHash {
+// It also returns the leftover bytes that do not fill a complete block.
+func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) ([]BlockHash, []byte) {
 loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
 if request == nil || request.Body == nil {
 loggerDebug.Info("Request or request data is nil, skipping hashing")
- return nil
+ return nil, nil
 }
 userInput, err := getUserInputBytes(request)
 if err != nil {
 loggerDebug.Error(err, "Failed to get user input bytes")
- return nil
+ return nil, nil
 }
+ prevBlockHash := defaultPrevBlock(request)
+ return hashInputWithPrevBlockHash(ctx, prevBlockHash, 0, userInput, cacheBlockSize, maxPrefixBlocks)
+}
- if len(userInput) < cacheBlockSize {
- loggerDebug.Info("Request body too small for prefix cache", "size", len(userInput), "block size", cacheBlockSize)
- return nil
- }
- if len(userInput) > cacheBlockSize*maxPrefixBlocks {
- loggerDebug.Info("Truncating input", "size", len(userInput), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
- userInput = userInput[:maxPrefixBlocks*cacheBlockSize]
- }
- // Split the body into blocks of size cacheBlockSize.
- // If the last block is smaller than cacheBlockSize, it will be ignored.
- res := make([]BlockHash, 0, len(userInput)/cacheBlockSize)
- // Add the model to the first block hash so that different models have different hashes even with the same body.
+func defaultPrevBlock(request *types.LLMRequest) BlockHash {
 h := xxhash.New()
+ // Add the model to the first block hash so that different models have different hashes even with the same body.
 _, _ = h.Write([]byte(request.TargetModel))
 if cacheSalt := request.Body.CacheSalt(); cacheSalt != "" {
 _, _ = h.Write([]byte(cacheSalt))
 }
- prevBlockHash := BlockHash(h.Sum64())
- for i := 0; i+cacheBlockSize <= len(userInput); i += cacheBlockSize {
+ return BlockHash(h.Sum64())
+}
+
+func hashInputWithPrevBlockHash(ctx context.Context, prevBlockHash BlockHash, prevBlockLength int, input []byte, cacheBlockSize int, maxPrefixBlocks int) ([]BlockHash, []byte) {
+ loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
+ if len(input)+prevBlockLength < cacheBlockSize {
+ loggerDebug.Info("Request body too small for prefix cache", "size", len(input), "block size", cacheBlockSize)
+ return nil, input
+ }
+ if len(input)+prevBlockLength > cacheBlockSize*maxPrefixBlocks {
+ loggerDebug.Info("Truncating input", "size", len(input), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
+ input = input[:(maxPrefixBlocks*cacheBlockSize - prevBlockLength)]
+ }
+ // Split the body into blocks of size cacheBlockSize.
+ // If the last block is smaller than cacheBlockSize, it will be ignored.
+ res := make([]BlockHash, 0, len(input)/cacheBlockSize)
+ lastOffset := 0
+ h := xxhash.New()
+ for i := 0; i+cacheBlockSize <= len(input); i += cacheBlockSize {
 h.Reset()
- _, _ = h.Write(userInput[i : i+cacheBlockSize])
+ _, _ = h.Write(input[i : i+cacheBlockSize])
 _, _ = h.Write(toBytes(prevBlockHash))
 res = append(res, BlockHash(h.Sum64()))
 prevBlockHash = res[len(res)-1]
+ lastOffset = i + cacheBlockSize
 }
- return res
+ return res, input[lastOffset:]
}

func toBytes(i BlockHash) []byte {
@@ -356,7 +376,39 @@ func getUserInputBytes(request *types.LLMRequest) ([]byte, error) {
 }
 // must be chat-completions request at this point, return bytes of entire messages
- return json.Marshal(request.Body.ChatCompletions.Messages)
+ return types.MarshalMessagesToJSON(request.Body.ChatCompletions.Messages...)
+}
+
+func (p *Plugin) ResponseComplete(ctx context.Context, request *types.LLMRequest, response *types.LLMResponse, targetPod *backend.Pod) {
+ state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
+ if err != nil {
+ log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId)
+ return
+ }
+ p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it.
+
+ responseForKVCache, err := response.FirstChoiceContent()
+ if err != nil {
+ log.FromContext(ctx).Error(err, "failed to get first choice content", "requestID", request.RequestId)
+ return
+ }
+ var input bytes.Buffer
+ input.Write(state.RestBytes)
+ input.Write(responseForKVCache)
+
+ server := ServerID(targetPod.NamespacedName)
+ prevBlockHash := defaultPrevBlock(request)
+ prevBlockHashLength := 0
+ if len(state.PrefixHashes) > 0 {
+ prevBlockHash = state.PrefixHashes[len(state.PrefixHashes)-1]
+ prevBlockHashLength = len(state.PrefixHashes)
+ }
+ hashBlocks, _ := hashInputWithPrevBlockHash(ctx, prevBlockHash, prevBlockHashLength, input.Bytes(), state.BlockSize, p.config.MaxPrefixBlocksToMatch)
+ p.wg.Add(1)
+ go func() {
+ p.indexer.Add(hashBlocks, server)
+ p.wg.Done()
+ }()
}

func getBlockSize(pods []types.Pod, defaultBlockSize int) int {
diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
index 176393783..0b83af5f3 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
@@ -199,6 +199,105 @@ func TestPrefixPluginCompletion(t *testing.T) {
 plugin.wg.Wait()
}

+func TestPrefixPluginCompletionWithResponse(t *testing.T) {
+ const defaultBlockSize = 4
+ config := Config{
+ DefaultBlockSize: defaultBlockSize,
+ MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
+ LRUCapacityPerServer: DefaultLRUCapacityPerServer,
+ }
+ plugin := New(context.Background(), config)
+
+ pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}
+ pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}
+ pods := []types.Pod{pod1, pod2}
+
+ // -- First Request --
+ // This initial request will populate the cache.
+ req1 := &types.LLMRequest{
+ RequestId: uuid.NewString(),
+ TargetModel: "test-model1",
+ Body: &types.LLMRequestBody{
+ Completions: &types.CompletionsRequest{
+ Prompt: "aaaaaa",
+ },
+ },
+ }
+ scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods)
+ state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String()))
+ assert.NoError(t, err)
+ t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
+ // Input size is 6, hash block size is 4, so the last 2 characters are ignored.
+ // Total hashes = 1 (for the "aaaa" block); the model name is folded into the seed hash rather than counted as a separate hash.
+ assert.Equal(t, 1, len(state.PrefixHashes), "number of hashes is incorrect")
+ assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers yet")
+ // The last 2 characters are recorded in restBytes of the state.
+ assert.Equal(t, 2, len(state.RestBytes), "number of restBytes is incorrect")
+ assert.Equal(t, defaultBlockSize, state.BlockSize, "blockSize is incorrect")
+ assert.Equal(t, float64(0), scores[pod1], "score for pod1 should be 0 on first request")
+ assert.Equal(t, float64(0), scores[pod2], "score for pod2 should be 0 on first request")
+
+ // Simulate that the scheduler picked pod1 for the first request.
+ schedulingResult := &types.SchedulingResult{
+ PrimaryProfileName: "default",
+ ProfileResults: map[string]*types.ProfileRunResult{
+ "default": {TargetPods: []types.Pod{pod1}},
+ },
+ }
+ plugin.PreRequest(context.Background(), req1, schedulingResult, 0)
+ plugin.wg.Wait()
+
+ // -- Simulate Response Completion --
+ // The ResponseComplete hook is called. The plugin should update pod1's KV cache
+ // with the full context of the completed interaction (prompt + response).
+ // - Initial Prompt: "aaaaaa"
+ // - Response Body: "bb"
+ // - Cached Sequence: "aaaaaabb" (length 8)
+ // This sequence creates two 4-character blocks to be cached: "aaaa" and "aabb".
+ resp1 := &types.LLMResponse{
+ Completion: &types.CompletionResponse{
+ Choices: []types.CompletionChoice{
+ {
+ Text: "bb",
+ },
+ },
+ },
+ }
+ plugin.ResponseComplete(context.Background(), req1, resp1, pod1.GetPod())
+ plugin.wg.Wait()
+
+ // -- Second Request: Multi-turn Follow-up --
+ // This request simulates a follow-up message in a chat. The prompt contains the
+ // entire conversation history ("aaaaaabb") plus new text ("cc").
+ // The plugin should find that the first two blocks ("aaaa", "aabb") of this new
+ // prompt are already cached on pod1, giving it a perfect match score of 1.0.
+ // Pod2 has no matching cache entries and should score 0.
+ req2 := &types.LLMRequest{
+ RequestId: uuid.NewString(),
+ TargetModel: "test-model1",
+ Body: &types.LLMRequestBody{
+ Completions: &types.CompletionsRequest{
+ Prompt: "aaaaaabbcc",
+ },
+ },
+ }
+ scores = plugin.Score(context.Background(), types.NewCycleState(), req2, pods)
+ state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, plugins.StateKey(plugin.TypedName().String()))
+ assert.NoError(t, err)
+ t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
+ // Input size is 10, hash block size is 4. The prompt "aaaaaabb" generates 2 hashes.
+ // The last 2 characters ("cc") are ignored.
+ assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect")
+ // It should find a server (pod1) that has cached the prefixes.
+ assert.Equal(t, 1, len(state.PrefixCacheServers), "a cached server should have been found")
+ // The last 2 characters ("cc") are recorded in restBytes of the state.
+ assert.Equal(t, 2, len(state.RestBytes), "number of restBytes is incorrect")
+ assert.Equal(t, defaultBlockSize, state.BlockSize, "blockSize is incorrect")
+ // The score for pod1 should be 1.0 because both prompt blocks ("aaaa" and "aabb") were found in its cache.
+ assert.Equal(t, float64(1), scores[pod1], "score for pod1 should be a perfect match")
+ assert.Equal(t, float64(0), scores[pod2], "score for pod2 should be 0")
+}
+
 func TestPrefixPluginChatCompletions(t *testing.T) {
 config := Config{
 DefaultBlockSize: 4,
@@ -278,6 +377,19 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
 plugin.PreRequest(context.Background(), req1, schedulingResult, 0)
 plugin.wg.Wait()

+ resp1 := &types.LLMResponse{
+ ChatCompletion: &types.ChatCompletionResponse{
+ Choices: []types.ChatChoice{
+ {
+ Message: types.Message{Role: "assistant", Content: "I'm doing well, thank you! How can I help you today?"},
+ },
+ },
+ },
+ }
+ // Simulate that resp1 has been recorded in the KV cache.
+ plugin.ResponseComplete(context.Background(), req1, resp1, pod1.GetPod())
+ plugin.wg.Wait()
+
 // Second request adds assistant response and new user message (conversation grows)
 req2 := &types.LLMRequest{
 RequestId: uuid.NewString(),
@@ -305,13 +417,27 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
 cachedBlocks := state.PrefixCacheServers[ServerID(pod1.GetPod().NamespacedName)]
 expectedScore := float64(cachedBlocks) / float64(extendedHashCount)
 assert.Equal(t, expectedScore, scores[pod1], "pod1 should have prefix cache hit")
+ assert.Greater(t, scores[pod1], float64(0.5), "given the response is also prefix cached the cache hit should be well above 0.5")
 assert.Equal(t, float64(0), scores[pod2], "pod2 should have no cache hit")

 // Simulate pod1 was picked again
 plugin.PreRequest(context.Background(), req2, schedulingResult, 0)
 plugin.wg.Wait()

- // Third request continues the conversation even further
+ resp2 := &types.LLMResponse{
+ ChatCompletion: &types.ChatCompletionResponse{
+ Choices: []types.ChatChoice{
+ {
+ Message: types.Message{Role: "assistant", Content: "Prefix caching is a technique where..."},
+ },
+ },
+ },
+ }
+ // Simulate that resp2 has been recorded in the KV cache.
+ plugin.ResponseComplete(context.Background(), req2, resp2, pod1.GetPod())
+ plugin.wg.Wait()
+
+ // Third request replays the entire conversation above, so the cache hit rate should reach 1.0.
 req3 := &types.LLMRequest{
 RequestId: uuid.NewString(),
 TargetModel: "test-model1",
@@ -323,7 +449,6 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
 {Role: "assistant", Content: "I'm doing well, thank you! How can I help you today?"},
 {Role: "user", Content: "Can you explain how prefix caching works?"},
 {Role: "assistant", Content: "Prefix caching is a technique where..."},
- {Role: "user", Content: "That's very helpful, thank you!"},
 },
 },
 },
@@ -340,7 +465,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
 cachedBlocks = state.PrefixCacheServers[ServerID(pod1.GetPod().NamespacedName)]
 expectedScore = float64(cachedBlocks) / float64(longHashCount)
 assert.Equal(t, expectedScore, scores[pod1], "pod1 should have higher prefix cache hit")
- assert.Greater(t, scores[pod1], float64(0.5), "cache hit rate should be substantial for growing conversation")
+ assert.Equal(t, float64(1), scores[pod1], "cache hit rate should reach 1.0 since the full conversation is cached")
 assert.Equal(t, float64(0), scores[pod2], "pod2 should still have no cache hit")
}
diff --git a/pkg/epp/scheduling/types/llmresponse.go b/pkg/epp/scheduling/types/llmresponse.go
new file mode 100644
index 000000000..3a9322df0
--- /dev/null
+++ b/pkg/epp/scheduling/types/llmresponse.go
@@ -0,0 +1,318 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package types
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "sort"
+)
+
+const (
+ // StreamDone is the special string indicating the end of a streaming response.
+ StreamDone = "[DONE]"
+)
+
+// LLMResponse is a structured representation of a parsed LLM response body.
+// An LLMResponse must contain exactly one of ChatCompletion or Completion.
+type LLMResponse struct {
+ // ChatCompletion is the representation of the OpenAI /v1/chat/completions response body.
+ ChatCompletion *ChatCompletionResponse `json:"chat_completion,omitempty"`
+ // Completion is the representation of the OpenAI /v1/completions response body.
+ Completion *CompletionResponse `json:"legacy_completion,omitempty"`
+}
+
+// FirstChoiceContent returns the content of the first choice in the response.
+func (res *LLMResponse) FirstChoiceContent() ([]byte, error) {
+ if res.ChatCompletion != nil && len(res.ChatCompletion.Choices) > 0 {
+ return MarshalMessagesToJSON(res.ChatCompletion.Choices[0].Message)
+ }
+ if res.Completion != nil && len(res.Completion.Choices) > 0 {
+ return []byte(res.Completion.Choices[0].Text), nil
+ }
+ return nil, errors.New("no choices found in the LLM response")
+}
+
+func (res *LLMResponse) Usage() *Usage {
+ if res.ChatCompletion != nil {
+ return res.ChatCompletion.Usage
+ }
+ return res.Completion.Usage
+}
+
+// ChatCompletionResponse represents the full response body for the chat completions API.
+type ChatCompletionResponse struct { + Choices []ChatChoice `json:"choices"` + Usage *Usage `json:"usage,omitempty"` +} + +func (r *ChatCompletionResponse) String() string { + if r == nil { + return nilString + } + contentLen := 0 + if len(r.Choices) > 0 { + contentLen = len(r.Choices[0].Message.Content) + } + return fmt.Sprintf("{ContentLength: %d, Usage: %s}", contentLen, r.Usage) +} + +// ChatChoice represents a single choice in the chat completion response. +type ChatChoice struct { + Message Message `json:"message"` + FinishReason string `json:"finish_reason"` +} + +// ChatMessage represents the message object within a choice. +type ChatMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// CompletionResponse represents the full response body for the legacy completions API. +type CompletionResponse struct { + Choices []CompletionChoice `json:"choices"` + Usage *Usage `json:"usage,omitempty"` +} + +func (r *CompletionResponse) String() string { + if r == nil { + return nilString + } + textLen := 0 + if len(r.Choices) > 0 { + textLen = len(r.Choices[0].Text) + } + return fmt.Sprintf("{TextLength: %d, Usage: %v}", textLen, r.Usage) +} + +// CompletionChoice represents a single choice in the legacy completion response. +type CompletionChoice struct { + Text string `json:"text"` + FinishReason string `json:"finish_reason"` +} + +// Usage represents the token usage data common to all response formats. +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +func (u *Usage) String() string { + if u == nil { + return nilString + } + return fmt.Sprintf("{Prompt: %d, Completion: %d, Total: %d}", u.PromptTokens, u.CompletionTokens, u.TotalTokens) +} + +// ChatCompletionStreamChoiceDelta represents the delta in a streaming choice. +type ChatCompletionStreamChoiceDelta struct { + Content string `json:"content,omitempty"` + Role string `json:"role,omitempty"` +} + +// ChatCompletionStreamChoice represents a choice in a streaming response. +type ChatCompletionStreamChoice struct { + Index int `json:"index"` + Delta ChatCompletionStreamChoiceDelta `json:"delta"` + FinishReason string `json:"finish_reason,omitempty"` +} + +// ChatCompletionChunk represents a chunk of a streaming chat completion response. +type ChatCompletionChunk struct { + Choices []ChatCompletionStreamChoice `json:"choices"` + Usage *Usage `json:"usage,omitempty"` +} + +// CompletionStreamChoice represents a choice in a streaming completion response. +type CompletionStreamChoice struct { + Text string `json:"text"` + Index int `json:"index"` + FinishReason string `json:"finish_reason,omitempty"` +} + +// CompletionChunk represents a chunk of a streaming completion response. +type CompletionChunk struct { + Choices []CompletionStreamChoice `json:"choices"` + Usage *Usage `json:"usage,omitempty"` +} + +// NewLLMResponseFromStream initializes an LLMResponse from a streaming response. +func NewLLMResponseFromStream(body []byte) (*LLMResponse, error) { + if len(body) == 0 { + return nil, errors.New("input bytes are empty") + } + + lines := bytes.Split(body, []byte("data: ")) + + // Determine stream type from the first data chunk. 
+ for _, line := range lines { + line = bytes.TrimSpace(line) + + jsonData := bytes.TrimPrefix(line, []byte("data: ")) + if len(jsonData) == 0 || string(jsonData) == StreamDone { + continue + } + + if bytes.Contains(jsonData, []byte(`"delta":`)) { + return processChatStream(lines) + } + if bytes.Contains(jsonData, []byte(`"text":`)) { + return processCompletionStream(lines) + } + } + + return nil, errors.New("failed to determine stream type or find choices") +} + +func processChatStream(lines [][]byte) (*LLMResponse, error) { + chatChoices := make(map[int]*ChatChoice) + var chatUsage *Usage + + for _, line := range lines { + line = bytes.TrimSpace(line) + jsonData := bytes.TrimPrefix(line, []byte("data: ")) + if len(jsonData) == 0 || string(jsonData) == StreamDone { + continue + } + + var chunk ChatCompletionChunk + if err := json.Unmarshal(jsonData, &chunk); err != nil { + continue // Ignore malformed chunks + } + + if chunk.Usage != nil { + chatUsage = chunk.Usage + } + for _, choiceChunk := range chunk.Choices { + if _, ok := chatChoices[choiceChunk.Index]; !ok { + chatChoices[choiceChunk.Index] = &ChatChoice{Message: Message{}} + } + choice := chatChoices[choiceChunk.Index] + choice.Message.Role += choiceChunk.Delta.Role + choice.Message.Content += choiceChunk.Delta.Content + if choiceChunk.FinishReason != "" { + choice.FinishReason = choiceChunk.FinishReason + } + } + } + + if len(chatChoices) == 0 && chatUsage == nil { + return nil, errors.New("no choices or usage found in chat stream") + } + + return aggregateChatStream(chatChoices, chatUsage), nil +} + +func processCompletionStream(lines [][]byte) (*LLMResponse, error) { + completionChoices := make(map[int]*CompletionChoice) + var completionUsage *Usage + + for _, line := range lines { + line = bytes.TrimSpace(line) + jsonData := bytes.TrimPrefix(line, []byte("data: ")) + if len(jsonData) == 0 || string(jsonData) == StreamDone { + continue + } + + var chunk CompletionChunk + if err := json.Unmarshal(jsonData, &chunk); err != nil { + continue // Ignore malformed chunks + } + + if chunk.Usage != nil { + completionUsage = chunk.Usage + } + for _, choiceChunk := range chunk.Choices { + if _, ok := completionChoices[choiceChunk.Index]; !ok { + completionChoices[choiceChunk.Index] = &CompletionChoice{} + } + choice := completionChoices[choiceChunk.Index] + choice.Text += choiceChunk.Text + if choiceChunk.FinishReason != "" { + choice.FinishReason = choiceChunk.FinishReason + } + } + } + + if len(completionChoices) == 0 && completionUsage == nil { + return nil, errors.New("no choices or usage found in completion stream") + } + + return aggregateCompletionStream(completionChoices, completionUsage), nil +} + +func aggregateChatStream(choices map[int]*ChatChoice, usage *Usage) *LLMResponse { + resp := &ChatCompletionResponse{Usage: usage} + keys := make([]int, 0, len(choices)) + for k := range choices { + keys = append(keys, k) + } + sort.Ints(keys) + finalChoices := make([]ChatChoice, len(keys)) + for i, k := range keys { + finalChoices[i] = *choices[k] + } + resp.Choices = finalChoices + + return &LLMResponse{ChatCompletion: resp} +} + +func aggregateCompletionStream(choices map[int]*CompletionChoice, usage *Usage) *LLMResponse { + resp := &CompletionResponse{Usage: usage} + keys := make([]int, 0, len(choices)) + for k := range choices { + keys = append(keys, k) + } + sort.Ints(keys) + finalChoices := make([]CompletionChoice, len(keys)) + for i, k := range keys { + finalChoices[i] = *choices[k] + } + resp.Choices = finalChoices + return 
&LLMResponse{Completion: resp} +} + +// NewLLMResponseFromBytes initializes an LLMResponse by trying to parse the data +// as a chat completion and then as a legacy completion response. +func NewLLMResponseFromBytes(body []byte) (*LLMResponse, error) { + if len(body) == 0 { + return nil, errors.New("input bytes are empty") + } + + // Attempt to unmarshal as a ChatCompletionResponse first. + var chatResp ChatCompletionResponse + if err := json.Unmarshal(body, &chatResp); err == nil { + // Check if the role is set to distinguish ChatCompletion and LegacyCompletion. + if len(chatResp.Choices) > 0 && chatResp.Choices[0].Message.Role != "" { + return &LLMResponse{ChatCompletion: &chatResp}, nil + } + } + + // Try to unmarshal as a LegacyCompletionResponse. + var legacyResp CompletionResponse + if err := json.Unmarshal(body, &legacyResp); err == nil { + if len(legacyResp.Choices) > 0 { + return &LLMResponse{Completion: &legacyResp}, nil + } + } + + return nil, errors.New("failed to unmarshal body into any known LLM response format") +} diff --git a/pkg/epp/scheduling/types/llmresponse_test.go b/pkg/epp/scheduling/types/llmresponse_test.go new file mode 100644 index 000000000..698cf3393 --- /dev/null +++ b/pkg/epp/scheduling/types/llmresponse_test.go @@ -0,0 +1,513 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package types + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestNewLLMResponseFromBytes(t *testing.T) { + chatCompletionJSON := `{ + "choices": [ + { + "message": { + "role": "assistant", + "content": "Hello!" + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 1, + "completion_tokens": 2, + "total_tokens": 3 + } + }` + + legacyCompletionJSON := `{ + "choices": [ + { + "text": "Hello there!", + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 4, + "completion_tokens": 5, + "total_tokens": 9 + } + }` + + chatCompletionEmptyChoicesJSON := `{ + "choices": [], + "usage": { + "prompt_tokens": 1, + "completion_tokens": 2, + "total_tokens": 3 + } + }` + + legacyCompletionEmptyChoicesJSON := `{ + "choices": [], + "usage": { + "prompt_tokens": 4, + "completion_tokens": 5, + "total_tokens": 9 + } + }` + + chatCompletionEmptyUsageJSON := `{ + "choices": [ + { + "message": { + "role": "assistant", + "content": "Hello!" 
+ }, + "finish_reason": "stop" + } + ] + }` + + legacyCompletionEmptyUsageJSON := `{ + "choices": [ + { + "text": "Hello there!", + "finish_reason": "stop" + } + ] + }` + + invalidJSON := `{"invalid": json}` + unstructuredJSON := `{"foo": "bar"}` + + testCases := []struct { + name string + input []byte + want *LLMResponse + wantError bool + }{ + { + name: "valid chat completion response", + input: []byte(chatCompletionJSON), + want: &LLMResponse{ + ChatCompletion: &ChatCompletionResponse{ + Choices: []ChatChoice{ + { + Message: Message{ + Role: "assistant", + Content: "Hello!", + }, + FinishReason: "stop", + }, + }, + Usage: &Usage{ + PromptTokens: 1, + CompletionTokens: 2, + TotalTokens: 3, + }, + }, + }, + wantError: false, + }, + { + name: "valid legacy completion response", + input: []byte(legacyCompletionJSON), + want: &LLMResponse{ + Completion: &CompletionResponse{ + Choices: []CompletionChoice{ + { + Text: "Hello there!", + FinishReason: "stop", + }, + }, + Usage: &Usage{ + PromptTokens: 4, + CompletionTokens: 5, + TotalTokens: 9, + }, + }, + }, + wantError: false, + }, + { + name: "invalid json", + input: []byte(invalidJSON), + want: nil, + wantError: true, + }, + { + name: "empty input", + input: []byte{}, + want: nil, + wantError: true, + }, + { + name: "unstructured json", + input: []byte(unstructuredJSON), + want: nil, + wantError: true, + }, + { + name: "chat completion with empty choices", + input: []byte(chatCompletionEmptyChoicesJSON), + want: nil, + wantError: true, + }, + { + name: "legacy completion with empty choices", + input: []byte(legacyCompletionEmptyChoicesJSON), + want: nil, + wantError: true, + }, + { + name: "chat completion with empty usage", + input: []byte(chatCompletionEmptyUsageJSON), + want: &LLMResponse{ + ChatCompletion: &ChatCompletionResponse{ + Choices: []ChatChoice{ + { + Message: Message{ + Role: "assistant", + Content: "Hello!", + }, + FinishReason: "stop", + }, + }, + }, + }, + wantError: false, + }, + { + name: "legacy completion with empty usage", + input: []byte(legacyCompletionEmptyUsageJSON), + want: &LLMResponse{ + Completion: &CompletionResponse{ + Choices: []CompletionChoice{ + { + Text: "Hello there!", + FinishReason: "stop", + }, + }, + }, + }, + wantError: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := NewLLMResponseFromBytes(tc.input) + + if (err != nil) != tc.wantError { + t.Errorf("NewLLMResponseFromBytes() error = %v, wantError %v", err, tc.wantError) + return + } + + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Errorf("NewLLMResponseFromBytes() (-want +got): %v", diff) + } + }) + } +} + +func TestNewLLMResponseFromStream(t *testing.T) { + testCases := []struct { + name string + streamData []byte + want *LLMResponse + wantErr bool + errContains string + }{ + { + name: "valid chat stream with content and usage", + streamData: []byte(` + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"}}]} + + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Hello"}}]} + + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":" world"}}]} + + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]} + + data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[],"usage":{"prompt_tokens":5,"completion_tokens":7,"total_tokens":12}} + + data: [DONE] + `), + want: 
&LLMResponse{ + ChatCompletion: &ChatCompletionResponse{ + Choices: []ChatChoice{ + { + Message: Message{ + Role: "assistant", + Content: "Hello world", + }, + FinishReason: "stop", + }, + }, + Usage: &Usage{ + PromptTokens: 5, + CompletionTokens: 7, + TotalTokens: 12, + }, + }, + }, + wantErr: false, + }, + { + name: "valid completion stream with content and usage", + streamData: []byte(` + data: {"id":"cmpl-1","object":"text_completion","choices":[{"index":0,"text":"Hello"}]} + + data: {"id":"cmpl-1","object":"text_completion","choices":[{"index":0,"text":" world"}]} + + data: {"id":"cmpl-1","object":"text_completion","choices":[{"index":0,"text":"","finish_reason":"stop"}]} + + data: {"id":"cmpl-1","object":"text_completion","choices":[],"usage":{"prompt_tokens":5,"completion_tokens":7,"total_tokens":12}} + + data: [DONE] + `), + want: &LLMResponse{ + Completion: &CompletionResponse{ + Choices: []CompletionChoice{ + { + Text: "Hello world", + FinishReason: "stop", + }, + }, + Usage: &Usage{ + PromptTokens: 5, + CompletionTokens: 7, + TotalTokens: 12, + }, + }, + }, + }, + { + name: "empty stream data", + streamData: []byte(""), + wantErr: true, + errContains: "input bytes are empty", + }, + { + name: "stream with no choices", + streamData: []byte(`data: [DONE]`), + wantErr: true, + errContains: "failed to determine stream type", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := NewLLMResponseFromStream(tc.streamData) + + if tc.wantErr { + if err == nil { + t.Errorf("Expected an error, but got nil") + } + if err != nil && tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("Expected error to contain '%s', but got '%s'", tc.errContains, err.Error()) + } + } else { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Errorf("NewLLMResponseFromStream() mismatch (-want +got):\n%s", diff) + } + } + }) + } +} + +func TestFirstChoiceContent(t *testing.T) { + testCases := []struct { + name string + res *LLMResponse + want []byte + wantError bool + }{ + { + name: "chatCompletion with choice", + res: &LLMResponse{ + ChatCompletion: &ChatCompletionResponse{ + Choices: []ChatChoice{ + {Message: Message{Role: "assistant", Content: "Hello from Chat"}}, + }, + }, + }, + want: []byte(`{"Role":"assistant","Content":"Hello from Chat"},`), + }, + { + name: "legacyCompletion with choice", + res: &LLMResponse{ + Completion: &CompletionResponse{ + Choices: []CompletionChoice{ + {Text: "Hello from Legacy"}, + }, + }, + }, + want: []byte(`Hello from Legacy`), + }, + { + name: "chatCompletion with no choices", + res: &LLMResponse{ + ChatCompletion: &ChatCompletionResponse{ + Choices: []ChatChoice{}, + }, + }, + wantError: true, + }, + { + name: "legacyCompletion with no choices", + res: &LLMResponse{ + Completion: &CompletionResponse{ + Choices: []CompletionChoice{}, + }, + }, + wantError: true, + }, + { + name: "LLMResponse with all fields nil", + res: &LLMResponse{ + ChatCompletion: nil, + Completion: nil, + }, + wantError: true, + }, + { + name: "Empty LLMResponse struct", + res: &LLMResponse{}, + wantError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := tc.res.FirstChoiceContent() + if tc.wantError != (err != nil) { + t.Errorf("FirstChoiceContent() wantError is %v, but got error: %v", tc.wantError, err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Errorf("FirstChoiceContent() mismatch (-want 
+got):\n%s", diff) + } + }) + } +} + +func TestUsage_String(t *testing.T) { + var nilUsage *Usage + tests := []struct { + name string + u *Usage + want string + }{ + { + name: "nil usage", + u: nilUsage, + want: nilString, + }, + { + name: "non-nil usage", + u: &Usage{PromptTokens: 1, CompletionTokens: 2, TotalTokens: 3}, + want: "{Prompt: 1, Completion: 2, Total: 3}", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.u.String(); got != tt.want { + t.Errorf("Usage.String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestChatCompletionResponse_String(t *testing.T) { + var nilResp *ChatCompletionResponse + tests := []struct { + name string + r *ChatCompletionResponse + want string + }{ + { + name: "nil response", + r: nilResp, + want: nilString, + }, + { + name: "response with no choices", + r: &ChatCompletionResponse{Choices: []ChatChoice{}, Usage: &Usage{}}, + want: "{ContentLength: 0, Usage: {Prompt: 0, Completion: 0, Total: 0}}", + }, + { + name: "response with choices", + r: &ChatCompletionResponse{ + Choices: []ChatChoice{ + {Message: Message{Content: "hello"}}, + }, + Usage: &Usage{PromptTokens: 1, CompletionTokens: 2, TotalTokens: 3}, + }, + want: "{ContentLength: 5, Usage: {Prompt: 1, Completion: 2, Total: 3}}", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.r.String(); got != tt.want { + t.Errorf("ChatCompletionResponse.String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestLegacyCompletionResponse_String(t *testing.T) { + var nilResp *CompletionResponse + tests := []struct { + name string + r *CompletionResponse + want string + }{ + { + name: "nil response", + r: nilResp, + want: nilString, + }, + { + name: "response with no choices", + r: &CompletionResponse{Choices: []CompletionChoice{}, Usage: &Usage{}}, + want: "{TextLength: 0, Usage: {Prompt: 0, Completion: 0, Total: 0}}", + }, + { + name: "response with choices", + r: &CompletionResponse{ + Choices: []CompletionChoice{ + {Text: "hello world"}, + }, + Usage: &Usage{PromptTokens: 1, CompletionTokens: 2, TotalTokens: 3}, + }, + want: "{TextLength: 11, Usage: {Prompt: 1, Completion: 2, Total: 3}}", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.r.String(); got != tt.want { + t.Errorf("LegacyCompletionResponse.String() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go index c93f0c5ac..18c2a6831 100644 --- a/pkg/epp/scheduling/types/types.go +++ b/pkg/epp/scheduling/types/types.go @@ -17,13 +17,18 @@ limitations under the License. package types import ( + "bytes" + "encoding/json" "fmt" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" ) -const nilString = "" +const ( + nilString = "" + messageSplit = "," +) // LLMRequest is a structured representation of the fields we parse out of the LLMRequest body. type LLMRequest struct { @@ -125,6 +130,25 @@ type Message struct { Content string // TODO: support multi-modal content } +// MarshalMessagesToJSON converts a slice of Message structs into a JSON byte slice. +// This is used to create a consistent byte representation for prefix caching calculations, +// allowing us to identify common prefixes between LLM requests and responses. 
+func MarshalMessagesToJSON(messages ...Message) ([]byte, error) { + if len(messages) == 0 { + return []byte{}, nil + } + var buf bytes.Buffer + for _, msg := range messages { + jsonBytes, err := json.Marshal(msg) + if err != nil { + return []byte{}, err + } + buf.Write(jsonBytes) + buf.WriteString(messageSplit) + } + return buf.Bytes(), nil +} + type Pod interface { GetPod() *backend.Pod GetMetrics() *backendmetrics.MetricsState diff --git a/pkg/epp/scheduling/types/types_test.go b/pkg/epp/scheduling/types/types_test.go new file mode 100644 index 000000000..711b6dffb --- /dev/null +++ b/pkg/epp/scheduling/types/types_test.go @@ -0,0 +1,69 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package types + +import ( + "bytes" + "testing" +) + +func TestMarshalMessagesToJSON(t *testing.T) { + testCases := []struct { + name string + messages []Message + want []byte + wantErr bool + }{ + { + name: "empty messages", + messages: []Message{}, + want: []byte{}, + wantErr: false, + }, + { + name: "single message", + messages: []Message{ + {Role: "user", Content: "Hello"}, + }, + want: []byte(`{"Role":"user","Content":"Hello"},`), + wantErr: false, + }, + { + name: "multiple messages", + messages: []Message{ + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "Hi there!"}, + }, + want: []byte(`{"Role":"user","Content":"Hello"},{"Role":"assistant","Content":"Hi there!"},`), + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := MarshalMessagesToJSON(tc.messages...) + if (err != nil) != tc.wantErr { + t.Errorf("MarshalMessagesToJSON() error = %v, wantErr %v", err, tc.wantErr) + return + } + + if !bytes.Equal(got, tc.want) { + t.Errorf("MarshalMessagesToJSON() got = %s, want %s", string(got), string(tc.want)) + } + }) + } +}