From 6e82e8f860ca59191879f921e00f9d5b44190b7b Mon Sep 17 00:00:00 2001 From: BenjaminBraunDev Date: Fri, 3 Oct 2025 23:22:30 +0000 Subject: [PATCH 1/4] Fix function comment and pass existing logger into HandleResponseBodyStreaming --- pkg/epp/handlers/response.go | 2 +- pkg/epp/handlers/response_test.go | 3 ++- pkg/epp/handlers/server.go | 2 +- pkg/epp/requestcontrol/director.go | 4 ++-- pkg/epp/requestcontrol/director_test.go | 4 +++- pkg/epp/server/server_test.go | 3 ++- 6 files changed, 11 insertions(+), 7 deletions(-) diff --git a/pkg/epp/handlers/response.go b/pkg/epp/handlers/response.go index 1cbacbae3..5760cbfc6 100644 --- a/pkg/epp/handlers/response.go +++ b/pkg/epp/handlers/response.go @@ -68,7 +68,7 @@ func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *Reques // The function is to handle streaming response if the modelServer is streaming. func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) { logger := log.FromContext(ctx) - _, err := s.director.HandleResponseBodyStreaming(ctx, reqCtx) + _, err := s.director.HandleResponseBodyStreaming(ctx, reqCtx, logger) if err != nil { logger.Error(err, "error in HandleResponseBodyStreaming") } diff --git a/pkg/epp/handlers/response_test.go b/pkg/epp/handlers/response_test.go index 63b2de0da..290161167 100644 --- a/pkg/epp/handlers/response_test.go +++ b/pkg/epp/handlers/response_test.go @@ -21,6 +21,7 @@ import ( "encoding/json" "testing" + "github.com/go-logr/logr" "github.com/google/go-cmp/cmp" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" @@ -62,7 +63,7 @@ data: [DONE] type mockDirector struct{} -func (m *mockDirector) HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) { +func (m *mockDirector) HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext, logger logr.Logger) (*RequestContext, error) { return reqCtx, nil } func (m *mockDirector) HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) { diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 001fdc344..59cde8949 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -55,7 +55,7 @@ func NewStreamingServer(datastore Datastore, director Director) *StreamingServer type Director interface { HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) HandleResponseReceived(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) - HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) + HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext, logger logr.Logger) (*RequestContext, error) HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) GetRandomPod() *backend.Pod } diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index d7de39d4a..755a69957 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -27,6 +27,7 @@ import ( "strings" "time" + "github.com/go-logr/logr" "sigs.k8s.io/controller-runtime/pkg/log" v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" @@ -274,8 +275,7 @@ func (d *Director) HandleResponseReceived(ctx context.Context, reqCtx *handlers. } // HandleResponseBodyStreaming is called every time a chunk of the response body is received. 
-func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) { - logger := log.FromContext(ctx).WithValues("stage", "bodyChunk") +func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext, logger logr.Logger) (*handlers.RequestContext, error) { logger.V(logutil.TRACE).Info("Entering HandleResponseBodyChunk") response := &Response{ RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey], diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go index ffd62da36..b0bd554a5 100644 --- a/pkg/epp/requestcontrol/director_test.go +++ b/pkg/epp/requestcontrol/director_test.go @@ -32,6 +32,7 @@ import ( "k8s.io/apimachinery/pkg/types" clientgoscheme "k8s.io/client-go/kubernetes/scheme" "sigs.k8s.io/controller-runtime/pkg/client/fake" + "sigs.k8s.io/controller-runtime/pkg/log" v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" "sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2" @@ -624,6 +625,7 @@ func TestDirector_HandleResponseStreaming(t *testing.T) { ds := datastore.NewDatastore(t.Context(), nil) mockSched := &mockScheduler{} director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithResponseStreamingPlugins(ps1)) + logger := log.FromContext(ctx) reqCtx := &handlers.RequestContext{ Request: &handlers.Request{ @@ -637,7 +639,7 @@ func TestDirector_HandleResponseStreaming(t *testing.T) { TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}}, } - _, err := director.HandleResponseBodyStreaming(ctx, reqCtx) + _, err := director.HandleResponseBodyStreaming(ctx, reqCtx, logger) if err != nil { t.Fatalf("HandleResponseBodyStreaming() returned unexpected error: %v", err) } diff --git a/pkg/epp/server/server_test.go b/pkg/epp/server/server_test.go index 7220c04b4..f3fa16bfd 100644 --- a/pkg/epp/server/server_test.go +++ b/pkg/epp/server/server_test.go @@ -22,6 +22,7 @@ import ( "testing" pb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + "github.com/go-logr/logr" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -185,7 +186,7 @@ func (ts *testDirector) HandleResponseReceived(ctx context.Context, reqCtx *hand return reqCtx, nil } -func (ts *testDirector) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) { +func (ts *testDirector) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext, logger logr.Logger) (*handlers.RequestContext, error) { return reqCtx, nil } From eb6596584eacb4c8fac5cf03be5374a4ae612ae2 Mon Sep 17 00:00:00 2001 From: BenjaminBraunDev Date: Wed, 15 Oct 2025 21:15:40 +0000 Subject: [PATCH 2/4] Revert logging parameter addition, keeping consistent with existing format for plugins --- pkg/epp/handlers/response.go | 2 +- pkg/epp/handlers/response_test.go | 3 +-- pkg/epp/handlers/server.go | 2 +- pkg/epp/requestcontrol/director.go | 4 ++-- pkg/epp/requestcontrol/director_test.go | 4 +--- pkg/epp/server/server_test.go | 3 +-- 6 files changed, 7 insertions(+), 11 deletions(-) diff --git a/pkg/epp/handlers/response.go b/pkg/epp/handlers/response.go index 5760cbfc6..1cbacbae3 100644 --- a/pkg/epp/handlers/response.go +++ b/pkg/epp/handlers/response.go @@ -68,7 +68,7 @@ func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *Reques // The function is to handle streaming response if the modelServer is streaming. 
func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) { logger := log.FromContext(ctx) - _, err := s.director.HandleResponseBodyStreaming(ctx, reqCtx, logger) + _, err := s.director.HandleResponseBodyStreaming(ctx, reqCtx) if err != nil { logger.Error(err, "error in HandleResponseBodyStreaming") } diff --git a/pkg/epp/handlers/response_test.go b/pkg/epp/handlers/response_test.go index 290161167..63b2de0da 100644 --- a/pkg/epp/handlers/response_test.go +++ b/pkg/epp/handlers/response_test.go @@ -21,7 +21,6 @@ import ( "encoding/json" "testing" - "github.com/go-logr/logr" "github.com/google/go-cmp/cmp" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" @@ -63,7 +62,7 @@ data: [DONE] type mockDirector struct{} -func (m *mockDirector) HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext, logger logr.Logger) (*RequestContext, error) { +func (m *mockDirector) HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) { return reqCtx, nil } func (m *mockDirector) HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) { diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 59cde8949..001fdc344 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -55,7 +55,7 @@ func NewStreamingServer(datastore Datastore, director Director) *StreamingServer type Director interface { HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) HandleResponseReceived(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) - HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext, logger logr.Logger) (*RequestContext, error) + HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) GetRandomPod() *backend.Pod } diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index 755a69957..d7de39d4a 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -27,7 +27,6 @@ import ( "strings" "time" - "github.com/go-logr/logr" "sigs.k8s.io/controller-runtime/pkg/log" v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1" @@ -275,7 +274,8 @@ func (d *Director) HandleResponseReceived(ctx context.Context, reqCtx *handlers. } // HandleResponseBodyStreaming is called every time a chunk of the response body is received. 
-func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext, logger logr.Logger) (*handlers.RequestContext, error) {
+func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
+	logger := log.FromContext(ctx).WithValues("stage", "bodyChunk")
 	logger.V(logutil.TRACE).Info("Entering HandleResponseBodyChunk")
 	response := &Response{
 		RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go
index b0bd554a5..ffd62da36 100644
--- a/pkg/epp/requestcontrol/director_test.go
+++ b/pkg/epp/requestcontrol/director_test.go
@@ -32,7 +32,6 @@ import (
 	"k8s.io/apimachinery/pkg/types"
 	clientgoscheme "k8s.io/client-go/kubernetes/scheme"
 	"sigs.k8s.io/controller-runtime/pkg/client/fake"
-	"sigs.k8s.io/controller-runtime/pkg/log"
 
 	v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1"
 	"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
@@ -625,7 +624,6 @@ func TestDirector_HandleResponseStreaming(t *testing.T) {
 	ds := datastore.NewDatastore(t.Context(), nil)
 	mockSched := &mockScheduler{}
 	director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithResponseStreamingPlugins(ps1))
-	logger := log.FromContext(ctx)
 
 	reqCtx := &handlers.RequestContext{
 		Request: &handlers.Request{
@@ -639,7 +637,7 @@ func TestDirector_HandleResponseStreaming(t *testing.T) {
 		TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}},
 	}
 
-	_, err := director.HandleResponseBodyStreaming(ctx, reqCtx, logger)
+	_, err := director.HandleResponseBodyStreaming(ctx, reqCtx)
 	if err != nil {
 		t.Fatalf("HandleResponseBodyStreaming() returned unexpected error: %v", err)
 	}
diff --git a/pkg/epp/server/server_test.go b/pkg/epp/server/server_test.go
index f3fa16bfd..7220c04b4 100644
--- a/pkg/epp/server/server_test.go
+++ b/pkg/epp/server/server_test.go
@@ -22,7 +22,6 @@ import (
 	"testing"
 
 	pb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
-	"github.com/go-logr/logr"
 
 	v1 "k8s.io/api/core/v1"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@@ -186,7 +185,7 @@ func (ts *testDirector) HandleResponseReceived(ctx context.Context, reqCtx *hand
 	return reqCtx, nil
 }
 
-func (ts *testDirector) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext, logger logr.Logger) (*handlers.RequestContext, error) {
+func (ts *testDirector) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
 	return reqCtx, nil
 }
 
From bb4fef9df7064d8d69be432933348ced6ef924f3 Mon Sep 17 00:00:00 2001
From: bobzetian
Date: Tue, 14 Oct 2025 20:35:16 +0000
Subject: [PATCH 3/4] Add response to prefix cache in non-streaming mode.
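
The EPP previously recorded only prompt prefixes in the prefix-cache
index, so in multi-turn conversations the model's own response was a
guaranteed cache miss on the follow-up request. This change stores the
full response body on handlers.Response, parses it into a structured
LLMResponse (chat or legacy completions), hands the first choice's
content to ResponseComplete plugins, and lets the prefix plugin hash
the leftover prompt bytes (RestBytes) plus the response and record the
resulting blocks for the target pod.

A rough sketch of the parsing step introduced here (the JSON value is
illustrative only; the helper names are the ones added in this patch):

    body := []byte(`{"choices":[{"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}]}`)
    res, err := schedulingtypes.NewLLMResponseFromBytes(body)
    // err == nil; res.ChatCompletion is set because the first choice has a
    // message with a non-empty role. A body carrying "text" in choices[0]
    // would populate res.LegacyCompletion instead.
    content := res.GetFirstChoiceContent() // content == "Hello!"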
--- pkg/epp/handlers/server.go | 1 + pkg/epp/requestcontrol/director.go | 13 +- pkg/epp/requestcontrol/director_test.go | 23 +- .../framework/plugins/multi/prefix/plugin.go | 86 +++- .../plugins/multi/prefix/plugin_test.go | 84 ++++ pkg/epp/scheduling/types/llmresponse.go | 135 ++++++ pkg/epp/scheduling/types/llmresponse_test.go | 399 ++++++++++++++++++ 7 files changed, 717 insertions(+), 24 deletions(-) create mode 100644 pkg/epp/scheduling/types/llmresponse.go create mode 100644 pkg/epp/scheduling/types/llmresponse_test.go diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 001fdc344..d12477ff1 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -115,6 +115,7 @@ type Request struct { } type Response struct { Headers map[string]string + Body []byte } type StreamRequestState int diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index d7de39d4a..dea426acc 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -289,13 +289,20 @@ func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *hand // HandleResponseBodyComplete is called when the response body is fully received. func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) { - logger := log.FromContext(ctx).WithValues("stage", "bodyChunk") + requestID := reqCtx.Request.Headers[requtil.RequestIdHeaderKey] + logger := log.FromContext(ctx).WithValues("stage", "bodyChunk", requtil.RequestIdHeaderKey, requestID) logger.V(logutil.DEBUG).Info("Entering HandleResponseBodyComplete") + llmResponse, err := schedulingtypes.NewLLMResponseFromBytes(reqCtx.Response.Body) + if err != nil { + logger.Error(err, "HandleResponseBodyComplete: failed to convert the response to LLMResponse.") + return reqCtx, err + } response := &Response{ - RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey], + RequestId: requestID, Headers: reqCtx.Response.Headers, + // Currently use the first choice as the response body to process. + Body: llmResponse.GetFirstChoiceContent(), } - d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod) logger.V(logutil.DEBUG).Info("Exiting HandleResponseBodyComplete") diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go index ffd62da36..65ff9c9d5 100644 --- a/pkg/epp/requestcontrol/director_test.go +++ b/pkg/epp/requestcontrol/director_test.go @@ -661,6 +661,23 @@ func TestDirector_HandleResponseComplete(t *testing.T) { mockSched := &mockScheduler{} director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithResponseCompletePlugins(pc1)) + chatCompletionJSON := `{ + "choices": [ + { + "message": { + "role": "assistant", + "content": "Hello!" 

+				},
+				"finish_reason": "stop"
+			}
+		],
+		"usage": {
+			"prompt_tokens": 1,
+			"completion_tokens": 2,
+			"total_tokens": 3
+		}
+	}`
+
 	reqCtx := &handlers.RequestContext{
 		Request: &handlers.Request{
 			Headers: map[string]string{
@@ -669,6 +686,7 @@ func TestDirector_HandleResponseComplete(t *testing.T) {
 		},
 		Response: &handlers.Response{
 			Headers: map[string]string{"X-Test-Complete-Header": "CompleteValue"},
+			Body:    []byte(chatCompletionJSON),
 		},
 		TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}},
 	}
@@ -682,11 +700,14 @@ func TestDirector_HandleResponseComplete(t *testing.T) {
 		t.Errorf("Scheduler.OnComplete RequestId mismatch (-want +got):\n%s", diff)
 	}
 	if diff := cmp.Diff(reqCtx.Response.Headers, pc1.lastRespOnComplete.Headers); diff != "" {
-		t.Errorf("Scheduler.OnComplete Headers mismatch (-want +got):\n%s", diff)
+		t.Errorf("Scheduler.OnComplete response headers mismatch (-want +got):\n%s", diff)
 	}
 	if diff := cmp.Diff("namespace1/test-pod-name", pc1.lastTargetPodOnComplete); diff != "" {
 		t.Errorf("Scheduler.OnComplete TargetPodName mismatch (-want +got):\n%s", diff)
 	}
+	if diff := cmp.Diff("Hello!", pc1.lastRespOnComplete.Body); diff != "" {
+		t.Errorf("Scheduler.OnComplete response body mismatch (-want +got):\n%s", diff)
+	}
 }
 
 const (
diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
index eb45edeab..c7b897732 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
@@ -17,6 +17,7 @@ limitations under the License.
 package prefix
 
 import (
+	"bytes"
 	"context"
 	"encoding/binary"
 	"encoding/json"
@@ -28,6 +29,7 @@ import (
 	k8stypes "k8s.io/apimachinery/pkg/types"
 	"sigs.k8s.io/controller-runtime/pkg/log"
 
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
 	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
@@ -117,6 +119,10 @@ var _ plugins.StateData = &SchedulingContextState{}
 type SchedulingContextState struct {
 	// PrefixHashes is a list of prefix hashes of the request prompt broken into blocks.
 	PrefixHashes []BlockHash
+	// RestBytes holds the trailing bytes that could not fill a complete block and were left over.
+	// If not empty, they are prepended to the response before it is hashed and added to the
+	// cache index as well, which matters especially in multi-turn scenarios.
+	RestBytes []byte
 	// A map of server to its longest prefix cache match length.
 	PrefixCacheServers map[ServerID]int
 }
@@ -193,9 +199,10 @@ func (p *Plugin) WithName(name string) *Plugin {
 
 // Score returns the scoring result for the given list of pods based on context.
 func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
 	// pre score step, hashing prompt and find longest prefix match.
-	hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config.DefaultBlockSize), p.config.MaxPrefixBlocksToMatch)
+	hashes, restBytes := hashPrompt(ctx, request, getBlockSize(pods, p.config.DefaultBlockSize), p.config.MaxPrefixBlocksToMatch)
 	state := &SchedulingContextState{
 		PrefixHashes:       hashes,
+		RestBytes:          restBytes,
 		PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
 	}
 
@@ -301,47 +308,59 @@ func (m *Plugin) CleanUpInactivePods(ctx context.Context, handle plugins.Handle)
 // hashPrompt divides the prompt into blocks and calculate the prefix cache for each block.
 // hash[0] is calculated including the model name and cache_salt(if provided), since different models generally don't share prefix cache.
 // For block i, hash(i) = hash(block i content, hash(i-1)).
-func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) []BlockHash {
+// It also returns the trailing bytes that did not fill a complete block.
+func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) ([]BlockHash, []byte) {
 	loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
 	if request == nil || request.Body == nil {
 		loggerDebug.Info("Request or request data is nil, skipping hashing")
-		return nil
+		return nil, nil
 	}
 	userInput, err := getUserInputBytes(request)
 	if err != nil {
 		loggerDebug.Error(err, "Failed to get user input bytes")
-		return nil
+		return nil, nil
 	}
+	prevBlockHash := defaultPrevBlock(request)
+	return hashInputWithPrevBlockHash(ctx, prevBlockHash, 0, userInput, cacheBlockSize, maxPrefixBlocks)
+}
 
-	if len(userInput) < cacheBlockSize {
-		loggerDebug.Info("Request body too small for prefix cache", "size", len(userInput), "block size", cacheBlockSize)
-		return nil
-	}
-	if len(userInput) > cacheBlockSize*maxPrefixBlocks {
-		loggerDebug.Info("Truncating input", "size", len(userInput), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
-		userInput = userInput[:maxPrefixBlocks*cacheBlockSize]
-	}
-	// Split the body into blocks of size cacheBlockSize.
-	// If the last block is smaller than cacheBlockSize, it will be ignored.
-	res := make([]BlockHash, 0, len(userInput)/cacheBlockSize)
-	// Add the model to the first block hash so that different models have different hashes even with the same body.
+func defaultPrevBlock(request *types.LLMRequest) BlockHash {
 	h := xxhash.New()
+	// Add the model to the first block hash so that different models have different hashes even with the same body.
 	_, _ = h.Write([]byte(request.TargetModel))
 	if cacheSalt := request.Body.CacheSalt(); cacheSalt != "" {
 		_, _ = h.Write([]byte(cacheSalt))
 	}
-	prevBlockHash := BlockHash(h.Sum64())
-	for i := 0; i+cacheBlockSize <= len(userInput); i += cacheBlockSize {
+	return BlockHash(h.Sum64())
+}
+
+func hashInputWithPrevBlockHash(ctx context.Context, prevBlockHash BlockHash, prevBlockLength int, input []byte, cacheBlockSize int, maxPrefixBlocks int) ([]BlockHash, []byte) {
+	loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
+	if len(input)+prevBlockLength < cacheBlockSize {
+		loggerDebug.Info("Request body too small for prefix cache", "size", len(input), "block size", cacheBlockSize)
+		return nil, input
+	}
+	if len(input)+prevBlockLength > cacheBlockSize*maxPrefixBlocks {
+		loggerDebug.Info("Truncating input", "size", len(input), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
+		input = input[:(maxPrefixBlocks*cacheBlockSize - prevBlockLength)]
+	}
+	// Split the body into blocks of size cacheBlockSize.
+	// If the last block is smaller than cacheBlockSize, it is not hashed and is
+	// returned to the caller as leftover bytes.
+	res := make([]BlockHash, 0, len(input)/cacheBlockSize)
+	lastOffset := 0
+	h := xxhash.New()
+	for i := 0; i+cacheBlockSize <= len(input); i += cacheBlockSize {
 		h.Reset()
-		_, _ = h.Write(userInput[i : i+cacheBlockSize])
+		_, _ = h.Write(input[i : i+cacheBlockSize])
 		_, _ = h.Write(toBytes(prevBlockHash))
 		res = append(res, BlockHash(h.Sum64()))
 		prevBlockHash = res[len(res)-1]
+		lastOffset = i + cacheBlockSize
 	}
-	return res
+	return res, input[lastOffset:]
 }
 
 func toBytes(i BlockHash) []byte {
@@ -359,6 +378,33 @@ func getUserInputBytes(request *types.LLMRequest) ([]byte, error) {
 	return json.Marshal(request.Body.ChatCompletions.Messages)
 }
 
+func (p *Plugin) ResponseComplete(ctx context.Context, request *types.LLMRequest, response *requestcontrol.Response, targetPod *backend.Pod) {
+	state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
+	if err != nil {
+		log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId)
+		return
+	}
+	p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it.
+	var input bytes.Buffer
+	input.Write(state.RestBytes)
+	input.Write([]byte(response.Body))
+
+	server := ServerID(targetPod.NamespacedName)
+	prevBlockHash := defaultPrevBlock(request)
+	prevBlockHashLength := 0
+	if len(state.PrefixHashes) > 0 {
+		prevBlockHash = state.PrefixHashes[len(state.PrefixHashes)-1]
+		prevBlockHashLength = len(state.PrefixHashes)
+	}
+	inputBytes := input.Bytes()
+	hashBlocks, _ := hashInputWithPrevBlockHash(ctx, prevBlockHash, prevBlockHashLength, inputBytes, p.config.DefaultBlockSize, p.config.MaxPrefixBlocksToMatch)
+	p.wg.Add(1)
+	go func() {
+		p.indexer.Add(hashBlocks, server)
+		p.wg.Done()
+	}()
+}
+
 func getBlockSize(pods []types.Pod, defaultBlockSize int) int {
 	if len(pods) == 0 {
 		return defaultBlockSize
diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
index 176393783..1958f607a 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
@@ -30,6 +30,7 @@ import (
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
 	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
 )
 
@@ -199,6 +200,89 @@ func TestPrefixPluginCompletion(t *testing.T) {
 	plugin.wg.Wait()
 }
 
+func TestPrefixPluginCompletionWithResponse(t *testing.T) {
+	config := Config{
+		DefaultBlockSize:       4,
+		MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
+		LRUCapacityPerServer:   DefaultLRUCapacityPerServer,
+	}
+	plugin := New(context.Background(), config)
+
+	pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}
+	pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}
+	pods := []types.Pod{pod1, pod2}
+
+	// -- First Request --
+	// This initial request will populate the cache.
+	req1 := &types.LLMRequest{
+		RequestId:   uuid.NewString(),
+		TargetModel: "test-model1",
+		Body: &types.LLMRequestBody{
+			Completions: &types.CompletionsRequest{
+				Prompt: "aaaaaa",
+			},
+		},
+	}
+	scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods)
+	state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String()))
+	assert.NoError(t, err)
+	t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
+	// Input size is 6, hash block size is 4, so the last 2 characters are ignored.
+	// Total hashes = 1 (for the "aaaa" block; the model name is folded into that block's hash).
+	assert.Equal(t, 1, len(state.PrefixHashes), "number of hashes is incorrect")
+	assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers yet")
+	assert.Equal(t, float64(0), scores[pod1], "score for pod1 should be 0 on first request")
+	assert.Equal(t, float64(0), scores[pod2], "score for pod2 should be 0 on first request")
+
+	// Simulate that the scheduler picked pod1 for the first request.
+	schedulingResult := &types.SchedulingResult{
+		PrimaryProfileName: "default",
+		ProfileResults: map[string]*types.ProfileRunResult{
+			"default": {TargetPods: []types.Pod{pod1}},
+		},
+	}
+	plugin.PreRequest(context.Background(), req1, schedulingResult, 0)
+	plugin.wg.Wait()
+
+	// -- Simulate Response Completion --
+	// The ResponseComplete hook is called. The plugin should update pod1's KV cache
+	// with the full context of the completed interaction (prompt + response).
+	// - Initial Prompt: "aaaaaa"
+	// - Response Body: "bb"
+	// - Cached Sequence: "aaaaaabb" (length 8)
+	// This sequence creates two 4-character blocks to be cached: "aaaa" and "aabb".
+	plugin.ResponseComplete(context.Background(), req1, &requestcontrol.Response{Body: "bb"}, pod1.GetPod())
+	plugin.wg.Wait()
+
+	// -- Second Request: Multi-turn Follow-up --
+	// This request simulates a follow-up message in a chat. The prompt contains the
+	// entire conversation history ("aaaaaabb") plus new text ("cc").
+	// The plugin should find that the first two blocks ("aaaa", "aabb") of this new
+	// prompt are already cached on pod1, giving it a perfect match score of 1.0.
+	// Pod2 has no matching cache entries and should score 0.
+	req2 := &types.LLMRequest{
+		RequestId:   uuid.NewString(),
+		TargetModel: "test-model1",
+		Body: &types.LLMRequestBody{
+			Completions: &types.CompletionsRequest{
+				Prompt: "aaaaaabbcc",
+			},
+		},
+	}
+	scores = plugin.Score(context.Background(), types.NewCycleState(), req2, pods)
+	state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, plugins.StateKey(plugin.TypedName().String()))
+	assert.NoError(t, err)
+	t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
+	// Input size is 10, hash block size is 4. The prompt "aaaaaabb" generates 2 hashes.
+	// The last 2 characters ("cc") are ignored.
+	assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect")
+	// It should find a server (pod1) that has cached the prefixes.
+	assert.Equal(t, 1, len(state.PrefixCacheServers), "a cached server should have been found")
+	// The score for pod1 should be 1.0 because both prompt blocks ("aaaa" and "aabb") were found in its cache.
+ assert.Equal(t, float64(1), scores[pod1], "score for pod1 should be a perfect match") + assert.Equal(t, float64(0), scores[pod2], "score for pod2 should be 0") +} + func TestPrefixPluginChatCompletions(t *testing.T) { config := Config{ DefaultBlockSize: 4, diff --git a/pkg/epp/scheduling/types/llmresponse.go b/pkg/epp/scheduling/types/llmresponse.go new file mode 100644 index 000000000..1061f3ccc --- /dev/null +++ b/pkg/epp/scheduling/types/llmresponse.go @@ -0,0 +1,135 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package types + +import ( + "encoding/json" + "fmt" +) + +// LLMResponse is a structured representation of a parsed LLM response body. +// An LLMResponse must contain exactly one of ChatCompletion or LegacyCompletion. +type LLMResponse struct { + // ChatCompletion is the representation of the OpenAI /v1/chat/completions response body. + ChatCompletion *ChatCompletionResponse `json:"chat_completion,omitempty"` + // LegacyCompletion is the representation of the OpenAI /v1/completions response body. + LegacyCompletion *LegacyCompletionResponse `json:"legacy_completion,omitempty"` +} + +// GetFirstChoiceContent extracts the primary text content from the first choice +// in either a ChatCompletion or a LegacyCompletion response. +func (res *LLMResponse) GetFirstChoiceContent() string { + if res.ChatCompletion != nil && len(res.ChatCompletion.Choices) > 0 { + return res.ChatCompletion.Choices[0].Message.Content + } else if res.LegacyCompletion != nil && len(res.LegacyCompletion.Choices) > 0 { + return res.LegacyCompletion.Choices[0].Text + } + return "" +} + +// ChatCompletionResponse represents the full response body for the chat completions API. +type ChatCompletionResponse struct { + Choices []ChatChoice `json:"choices"` + Usage *Usage `json:"usage,omitempty"` +} + +func (r *ChatCompletionResponse) String() string { + if r == nil { + return nilString + } + contentLen := 0 + if len(r.Choices) > 0 { + contentLen = len(r.Choices[0].Message.Content) + } + return fmt.Sprintf("{ContentLength: %d, Usage: %s}", contentLen, r.Usage) +} + +// ChatChoice represents a single choice in the chat completion response. +type ChatChoice struct { + Message ChatMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +// ChatMessage represents the message object within a choice. +type ChatMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// LegacyCompletionResponse represents the full response body for the legacy completions API. +type LegacyCompletionResponse struct { + Choices []LegacyChoice `json:"choices"` + Usage *Usage `json:"usage,omitempty"` +} + +func (r *LegacyCompletionResponse) String() string { + if r == nil { + return nilString + } + textLen := 0 + if len(r.Choices) > 0 { + textLen = len(r.Choices[0].Text) + } + return fmt.Sprintf("{TextLength: %d, Usage: %v}", textLen, r.Usage) +} + +// LegacyChoice represents a single choice in the legacy completion response. 
+type LegacyChoice struct { + Text string `json:"text"` + FinishReason string `json:"finish_reason"` +} + +// Usage represents the token usage data common to all response formats. +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +func (u *Usage) String() string { + if u == nil { + return nilString + } + return fmt.Sprintf("{Prompt: %d, Completion: %d, Total: %d}", u.PromptTokens, u.CompletionTokens, u.TotalTokens) +} + +// NewLLMResponseFromBytes initializes an LLMResponse by trying to parse the data +// as a chat completion and then as a legacy completion response. +func NewLLMResponseFromBytes(body []byte) (*LLMResponse, error) { + if len(body) == 0 { + return nil, fmt.Errorf("input bytes are empty") + } + + // Attempt to unmarshal as a ChatCompletionResponse first. + var chatResp ChatCompletionResponse + if err := json.Unmarshal(body, &chatResp); err == nil { + // Check if the role is set to distinguish ChatCompletion and LegacyCompletion. + if len(chatResp.Choices) > 0 && chatResp.Choices[0].Message.Role != "" { + return &LLMResponse{ChatCompletion: &chatResp}, nil + } + } + + // Try to unmarshal as a LegacyCompletionResponse. + var legacyResp LegacyCompletionResponse + if err := json.Unmarshal(body, &legacyResp); err == nil { + if len(legacyResp.Choices) > 0 { + return &LLMResponse{LegacyCompletion: &legacyResp}, nil + } + } + + return nil, fmt.Errorf("failed to unmarshal body into any known LLM response format") +} diff --git a/pkg/epp/scheduling/types/llmresponse_test.go b/pkg/epp/scheduling/types/llmresponse_test.go new file mode 100644 index 000000000..8904062a3 --- /dev/null +++ b/pkg/epp/scheduling/types/llmresponse_test.go @@ -0,0 +1,399 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package types + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestNewLLMResponseFromBytes(t *testing.T) { + chatCompletionJSON := `{ + "choices": [ + { + "message": { + "role": "assistant", + "content": "Hello!" + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 1, + "completion_tokens": 2, + "total_tokens": 3 + } + }` + + legacyCompletionJSON := `{ + "choices": [ + { + "text": "Hello there!", + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 4, + "completion_tokens": 5, + "total_tokens": 9 + } + }` + + chatCompletionEmptyChoicesJSON := `{ + "choices": [], + "usage": { + "prompt_tokens": 1, + "completion_tokens": 2, + "total_tokens": 3 + } + }` + + legacyCompletionEmptyChoicesJSON := `{ + "choices": [], + "usage": { + "prompt_tokens": 4, + "completion_tokens": 5, + "total_tokens": 9 + } + }` + + chatCompletionEmptyUsageJSON := `{ + "choices": [ + { + "message": { + "role": "assistant", + "content": "Hello!" 
+ }, + "finish_reason": "stop" + } + ] + }` + + legacyCompletionEmptyUsageJSON := `{ + "choices": [ + { + "text": "Hello there!", + "finish_reason": "stop" + } + ] + }` + + invalidJSON := `{"invalid": json}` + unstructuredJSON := `{"foo": "bar"}` + + testCases := []struct { + name string + input []byte + want *LLMResponse + wantError bool + }{ + { + name: "valid chat completion response", + input: []byte(chatCompletionJSON), + want: &LLMResponse{ + ChatCompletion: &ChatCompletionResponse{ + Choices: []ChatChoice{ + { + Message: ChatMessage{ + Role: "assistant", + Content: "Hello!", + }, + FinishReason: "stop", + }, + }, + Usage: &Usage{ + PromptTokens: 1, + CompletionTokens: 2, + TotalTokens: 3, + }, + }, + }, + wantError: false, + }, + { + name: "valid legacy completion response", + input: []byte(legacyCompletionJSON), + want: &LLMResponse{ + LegacyCompletion: &LegacyCompletionResponse{ + Choices: []LegacyChoice{ + { + Text: "Hello there!", + FinishReason: "stop", + }, + }, + Usage: &Usage{ + PromptTokens: 4, + CompletionTokens: 5, + TotalTokens: 9, + }, + }, + }, + wantError: false, + }, + { + name: "invalid json", + input: []byte(invalidJSON), + want: nil, + wantError: true, + }, + { + name: "empty input", + input: []byte{}, + want: nil, + wantError: true, + }, + { + name: "unstructured json", + input: []byte(unstructuredJSON), + want: nil, + wantError: true, + }, + { + name: "chat completion with empty choices", + input: []byte(chatCompletionEmptyChoicesJSON), + want: nil, + wantError: true, + }, + { + name: "legacy completion with empty choices", + input: []byte(legacyCompletionEmptyChoicesJSON), + want: nil, + wantError: true, + }, + { + name: "chat completion with empty usage", + input: []byte(chatCompletionEmptyUsageJSON), + want: &LLMResponse{ + ChatCompletion: &ChatCompletionResponse{ + Choices: []ChatChoice{ + { + Message: ChatMessage{ + Role: "assistant", + Content: "Hello!", + }, + FinishReason: "stop", + }, + }, + }, + }, + wantError: false, + }, + { + name: "legacy completion with empty usage", + input: []byte(legacyCompletionEmptyUsageJSON), + want: &LLMResponse{ + LegacyCompletion: &LegacyCompletionResponse{ + Choices: []LegacyChoice{ + { + Text: "Hello there!", + FinishReason: "stop", + }, + }, + }, + }, + wantError: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := NewLLMResponseFromBytes(tc.input) + + if (err != nil) != tc.wantError { + t.Errorf("NewLLMResponseFromBytes() error = %v, wantError %v", err, tc.wantError) + return + } + + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Errorf("NewLLMResponseFromBytes() (-want +got): %v", diff) + } + }) + } +} + +func TestGetFirstChoiceContent(t *testing.T) { + testCases := []struct { + name string + res *LLMResponse + want string + }{ + { + name: "chatCompletion with choice", + res: &LLMResponse{ + ChatCompletion: &ChatCompletionResponse{ + Choices: []ChatChoice{ + {Message: ChatMessage{Content: "Hello from Chat"}}, + }, + }, + }, + want: "Hello from Chat", + }, + { + name: "legacyCompletion with choice", + res: &LLMResponse{ + LegacyCompletion: &LegacyCompletionResponse{ + Choices: []LegacyChoice{ + {Text: "Hello from Legacy"}, + }, + }, + }, + want: "Hello from Legacy", + }, + { + name: "chatCompletion with no choices", + res: &LLMResponse{ + ChatCompletion: &ChatCompletionResponse{ + Choices: []ChatChoice{}, + }, + }, + want: "", + }, + { + name: "legacyCompletion with no choices", + res: &LLMResponse{ + LegacyCompletion: &LegacyCompletionResponse{ + Choices: 
[]LegacyChoice{},
+				},
+			},
+			want: "",
+		},
+		{
+			name: "LLMResponse with all fields nil",
+			res: &LLMResponse{
+				ChatCompletion:   nil,
+				LegacyCompletion: nil,
+			},
+			want: "",
+		},
+		{
+			name: "Empty LLMResponse struct",
+			res:  &LLMResponse{},
+			want: "",
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			got := tc.res.GetFirstChoiceContent()
+			if diff := cmp.Diff(tc.want, got); diff != "" {
+				t.Errorf("GetFirstChoiceContent() mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
+
+func TestUsage_String(t *testing.T) {
+	var nilUsage *Usage
+	tests := []struct {
+		name string
+		u    *Usage
+		want string
+	}{
+		{
+			name: "nil usage",
+			u:    nilUsage,
+			want: nilString,
+		},
+		{
+			name: "non-nil usage",
+			u:    &Usage{PromptTokens: 1, CompletionTokens: 2, TotalTokens: 3},
+			want: "{Prompt: 1, Completion: 2, Total: 3}",
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := tt.u.String(); got != tt.want {
+				t.Errorf("Usage.String() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestChatCompletionResponse_String(t *testing.T) {
+	var nilResp *ChatCompletionResponse
+	tests := []struct {
+		name string
+		r    *ChatCompletionResponse
+		want string
+	}{
+		{
+			name: "nil response",
+			r:    nilResp,
+			want: nilString,
+		},
+		{
+			name: "response with no choices",
+			r:    &ChatCompletionResponse{Choices: []ChatChoice{}, Usage: &Usage{}},
+			want: "{ContentLength: 0, Usage: {Prompt: 0, Completion: 0, Total: 0}}",
+		},
+		{
+			name: "response with choices",
+			r: &ChatCompletionResponse{
+				Choices: []ChatChoice{
+					{Message: ChatMessage{Content: "hello"}},
+				},
+				Usage: &Usage{PromptTokens: 1, CompletionTokens: 2, TotalTokens: 3},
+			},
+			want: "{ContentLength: 5, Usage: {Prompt: 1, Completion: 2, Total: 3}}",
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := tt.r.String(); got != tt.want {
+				t.Errorf("ChatCompletionResponse.String() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestLegacyCompletionResponse_String(t *testing.T) {
+	var nilResp *LegacyCompletionResponse
+	tests := []struct {
+		name string
+		r    *LegacyCompletionResponse
+		want string
+	}{
+		{
+			name: "nil response",
+			r:    nilResp,
+			want: nilString,
+		},
+		{
+			name: "response with no choices",
+			r:    &LegacyCompletionResponse{Choices: []LegacyChoice{}, Usage: &Usage{}},
+			want: "{TextLength: 0, Usage: {Prompt: 0, Completion: 0, Total: 0}}",
+		},
+		{
+			name: "response with choices",
+			r: &LegacyCompletionResponse{
+				Choices: []LegacyChoice{
+					{Text: "hello world"},
+				},
+				Usage: &Usage{PromptTokens: 1, CompletionTokens: 2, TotalTokens: 3},
+			},
+			want: "{TextLength: 11, Usage: {Prompt: 1, Completion: 2, Total: 3}}",
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := tt.r.String(); got != tt.want {
+				t.Errorf("LegacyCompletionResponse.String() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}

From eec941ce8b7924060344187007c4bf555d8668fb Mon Sep 17 00:00:00 2001
From: bobzetian
Date: Wed, 15 Oct 2025 04:58:11 +0000
Subject: [PATCH 4/4] Make ResponseComplete accept LLMResponse and update the
 encoding method of Messages in ChatCompletions.
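
ResponseComplete plugins now receive the parsed *types.LLMResponse
directly instead of a flattened requestcontrol.Response, and chat
messages are serialized through the shared types.MarshalMessagesToJSON
helper so response messages are encoded exactly like the prompt
messages were when their prefix hashes were computed. A minimal sketch
of plugin-side usage under the new signature (error handling elided;
illustrative only):

    func (p *Plugin) ResponseComplete(ctx context.Context, request *types.LLMRequest,
        response *types.LLMResponse, targetPod *backend.Pod) {
        content, err := response.FirstChoiceContent()
        if err != nil {
            return // no choices, nothing to record
        }
        _ = content // chat choices arrive JSON-encoded via MarshalMessagesToJSON
    }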
--- pkg/epp/handlers/server.go | 1 + pkg/epp/requestcontrol/director.go | 10 +-- pkg/epp/requestcontrol/director_test.go | 18 +++-- pkg/epp/requestcontrol/plugins.go | 2 +- .../framework/plugins/multi/prefix/plugin.go | 26 ++++--- .../plugins/multi/prefix/plugin_test.go | 53 ++++++++++++-- pkg/epp/scheduling/types/llmresponse.go | 45 ++++++------ pkg/epp/scheduling/types/llmresponse_test.go | 70 ++++++++++--------- pkg/epp/scheduling/types/types.go | 26 ++++++- pkg/epp/scheduling/types/types_test.go | 69 ++++++++++++++++++ 10 files changed, 229 insertions(+), 91 deletions(-) create mode 100644 pkg/epp/scheduling/types/types_test.go diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index d12477ff1..27be31c31 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -303,6 +303,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) break } + reqCtx.Response.Body = body reqCtx, responseErr = s.HandleResponseBody(ctx, reqCtx, responseBody) if responseErr != nil { if logger.V(logutil.DEBUG).Enabled() { diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index dea426acc..b91053d0f 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -297,13 +297,7 @@ func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handl logger.Error(err, "HandleResponseBodyComplete: failed to convert the response to LLMResponse.") return reqCtx, err } - response := &Response{ - RequestId: requestID, - Headers: reqCtx.Response.Headers, - // Currently use the first choice as the response body to process. - Body: llmResponse.GetFirstChoiceContent(), - } - d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod) + d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, llmResponse, reqCtx.TargetPod) logger.V(logutil.DEBUG).Info("Exiting HandleResponseBodyComplete") return reqCtx, nil @@ -353,7 +347,7 @@ func (d *Director) runResponseStreamingPlugins(ctx context.Context, request *sch } } -func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) { +func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *schedulingtypes.LLMResponse, targetPod *backend.Pod) { loggerDebug := log.FromContext(ctx).V(logutil.DEBUG) for _, plugin := range d.requestControlPlugins.responseCompletePlugins { loggerDebug.Info("Running ResponseComplete plugin", "plugin", plugin.TypedName()) diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go index 65ff9c9d5..896f14aa2 100644 --- a/pkg/epp/requestcontrol/director_test.go +++ b/pkg/epp/requestcontrol/director_test.go @@ -677,6 +677,10 @@ func TestDirector_HandleResponseComplete(t *testing.T) { "total_tokens": 3 } }` + wantLLMResponse, err := schedulingtypes.NewLLMResponseFromBytes([]byte(chatCompletionJSON)) + if err != nil { + t.Fatalf("NewLLMResponseFromBytes failed with error: %v", err) + } reqCtx := &handlers.RequestContext{ Request: &handlers.Request{ @@ -691,21 +695,15 @@ func TestDirector_HandleResponseComplete(t *testing.T) { TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}}, } - _, err := director.HandleResponseBodyComplete(ctx, reqCtx) + _, err = director.HandleResponseBodyComplete(ctx, reqCtx) if err != nil { t.Fatalf("HandleResponseBodyComplete() returned 
unexpected error: %v", err)
 	}
 
-	if diff := cmp.Diff("test-req-id-for-complete", pc1.lastRespOnComplete.RequestId); diff != "" {
-		t.Errorf("Scheduler.OnComplete RequestId mismatch (-want +got):\n%s", diff)
-	}
-	if diff := cmp.Diff(reqCtx.Response.Headers, pc1.lastRespOnComplete.Headers); diff != "" {
-		t.Errorf("Scheduler.OnComplete response headers mismatch (-want +got):\n%s", diff)
-	}
 	if diff := cmp.Diff("namespace1/test-pod-name", pc1.lastTargetPodOnComplete); diff != "" {
 		t.Errorf("Scheduler.OnComplete TargetPodName mismatch (-want +got):\n%s", diff)
 	}
-	if diff := cmp.Diff("Hello!", pc1.lastRespOnComplete.Body); diff != "" {
+	if diff := cmp.Diff(wantLLMResponse, pc1.lastRespOnComplete); diff != "" {
 		t.Errorf("Scheduler.OnComplete response body mismatch (-want +got):\n%s", diff)
 	}
 }
@@ -730,7 +728,7 @@ type testResponseStreaming struct {
 
 type testResponseComplete struct {
 	tn                      plugins.TypedName
-	lastRespOnComplete      *Response
+	lastRespOnComplete      *schedulingtypes.LLMResponse
 	lastTargetPodOnComplete string
 }
 
@@ -774,7 +772,7 @@ func (p *testResponseStreaming) ResponseStreaming(_ context.Context, _ *scheduli
 	p.lastTargetPodOnStreaming = targetPod.NamespacedName.String()
 }
 
-func (p *testResponseComplete) ResponseComplete(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
+func (p *testResponseComplete) ResponseComplete(_ context.Context, _ *schedulingtypes.LLMRequest, response *schedulingtypes.LLMResponse, targetPod *backend.Pod) {
 	p.lastRespOnComplete = response
 	p.lastTargetPodOnComplete = targetPod.NamespacedName.String()
 }
diff --git a/pkg/epp/requestcontrol/plugins.go b/pkg/epp/requestcontrol/plugins.go
index 44334c68f..1c7eefdac 100644
--- a/pkg/epp/requestcontrol/plugins.go
+++ b/pkg/epp/requestcontrol/plugins.go
@@ -55,5 +55,5 @@ type ResponseStreaming interface {
 // ResponseComplete is called by the director after the complete response is sent.
 type ResponseComplete interface {
 	plugins.Plugin
-	ResponseComplete(ctx context.Context, request *types.LLMRequest, response *Response, targetPod *backend.Pod)
+	ResponseComplete(ctx context.Context, request *types.LLMRequest, response *types.LLMResponse, targetPod *backend.Pod)
 }
diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
index c7b897732..fa654f73b 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
@@ -123,6 +123,8 @@ type SchedulingContextState struct {
 	// If not empty, they are prepended to the response before it is hashed and added to the
 	// cache index as well, which matters especially in multi-turn scenarios.
 	RestBytes []byte
+	// BlockSize is the block size used to calculate the hash of the request/response.
+	BlockSize int
 	// A map of server to its longest prefix cache match length.
 	PrefixCacheServers map[ServerID]int
 }
@@ -198,11 +200,13 @@ func (p *Plugin) WithName(name string) *Plugin {
 
 // Score returns the scoring result for the given list of pods based on context.
 func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
+	blockSize := getBlockSize(pods, p.config.DefaultBlockSize)
 	// pre score step, hashing prompt and find longest prefix match.
- hashes, restBytes := hashPrompt(ctx, request, getBlockSize(pods, p.config.DefaultBlockSize), p.config.MaxPrefixBlocksToMatch) + hashes, restBytes := hashPrompt(ctx, request, blockSize, p.config.MaxPrefixBlocksToMatch) state := &SchedulingContextState{ PrefixHashes: hashes, RestBytes: restBytes, + BlockSize: blockSize, PrefixCacheServers: p.matchLongestPrefix(ctx, hashes), } @@ -233,7 +237,6 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche targetPod := primaryProfileResult.TargetPods[0].GetPod() // get the first pod of the primary profile state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String())) - p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it if err != nil { log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId) return @@ -251,9 +254,7 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche total := len(state.PrefixHashes) matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)] - - blockSize := getBlockSize(primaryProfileResult.TargetPods, p.config.DefaultBlockSize) - metrics.RecordPrefixCacheMatch(matchLen*blockSize, total*blockSize) + metrics.RecordPrefixCacheMatch(matchLen*state.BlockSize, total*state.BlockSize) } // matchLongestPrefix returns a map of servers and length of prefix that each server caches. @@ -375,19 +376,25 @@ func getUserInputBytes(request *types.LLMRequest) ([]byte, error) { } // must be chat-completions request at this point, return bytes of entire messages - return json.Marshal(request.Body.ChatCompletions.Messages) + return types.MarshalMessagesToJSON(request.Body.ChatCompletions.Messages...) } -func (p *Plugin) ResponseComplete(ctx context.Context, request *types.LLMRequest, response *requestcontrol.Response, targetPod *backend.Pod) { +func (p *Plugin) ResponseComplete(ctx context.Context, request *types.LLMRequest, response *types.LLMResponse, targetPod *backend.Pod) { state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String())) if err != nil { log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId) return } p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it. 
+
+	responseForKVCache, err := response.FirstChoiceContent()
+	if err != nil {
+		log.FromContext(ctx).Error(err, "failed to get first choice content", "requestID", request.RequestId)
+		return
+	}
 	var input bytes.Buffer
 	input.Write(state.RestBytes)
-	input.Write([]byte(response.Body))
+	input.Write(responseForKVCache)
 
 	server := ServerID(targetPod.NamespacedName)
 	prevBlockHash := defaultPrevBlock(request)
@@ -396,8 +403,7 @@ func (p *Plugin) ResponseComplete(ctx context.Context, request *types.LLMRequest
 		prevBlockHash = state.PrefixHashes[len(state.PrefixHashes)-1]
 		prevBlockHashLength = len(state.PrefixHashes)
 	}
-	inputBytes := input.Bytes()
-	hashBlocks, _ := hashInputWithPrevBlockHash(ctx, prevBlockHash, prevBlockHashLength, inputBytes, p.config.DefaultBlockSize, p.config.MaxPrefixBlocksToMatch)
+	hashBlocks, _ := hashInputWithPrevBlockHash(ctx, prevBlockHash, prevBlockHashLength, input.Bytes(), state.BlockSize, p.config.MaxPrefixBlocksToMatch)
 	p.wg.Add(1)
 	go func() {
 		p.indexer.Add(hashBlocks, server)
diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
index 1958f607a..0b83af5f3 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
@@ -30,7 +30,6 @@ import (
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
 	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
-	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
 )
 
@@ -201,8 +200,9 @@ func TestPrefixPluginCompletion(t *testing.T) {
 }
 
 func TestPrefixPluginCompletionWithResponse(t *testing.T) {
+	const defaultBlockSize = 4
 	config := Config{
-		DefaultBlockSize:       4,
+		DefaultBlockSize:       defaultBlockSize,
 		MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
 		LRUCapacityPerServer:   DefaultLRUCapacityPerServer,
 	}
@@ -231,6 +231,9 @@ func TestPrefixPluginCompletionWithResponse(t *testing.T) {
 	// Total hashes = 1 (for the "aaaa" block; the model name is folded into that block's hash).
 	assert.Equal(t, 1, len(state.PrefixHashes), "number of hashes is incorrect")
 	assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers yet")
+	// The last 2 characters are recorded in restBytes of the state.
+	assert.Equal(t, 2, len(state.RestBytes), "number of restBytes is incorrect")
+	assert.Equal(t, defaultBlockSize, state.BlockSize, "blockSize is incorrect")
 	assert.Equal(t, float64(0), scores[pod1], "score for pod1 should be 0 on first request")
 	assert.Equal(t, float64(0), scores[pod2], "score for pod2 should be 0 on first request")
 
@@ -251,7 +254,16 @@ func TestPrefixPluginCompletionWithResponse(t *testing.T) {
 	// - Response Body: "bb"
 	// - Cached Sequence: "aaaaaabb" (length 8)
 	// This sequence creates two 4-character blocks to be cached: "aaaa" and "aabb".
-	plugin.ResponseComplete(context.Background(), req1, &requestcontrol.Response{Body: "bb"}, pod1.GetPod())
+	resp1 := &types.LLMResponse{
+		Completion: &types.CompletionResponse{
+			Choices: []types.CompletionChoice{
+				{
+					Text: "bb",
+				},
+			},
+		},
+	}
+	plugin.ResponseComplete(context.Background(), req1, resp1, pod1.GetPod())
 	plugin.wg.Wait()
 
 	// -- Second Request: Multi-turn Follow-up --
@@ -278,6 +290,9 @@ func TestPrefixPluginCompletionWithResponse(t *testing.T) {
 	assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect")
 	// It should find a server (pod1) that has cached the prefixes.
 	assert.Equal(t, 1, len(state.PrefixCacheServers), "a cached server should have been found")
+	// The last 2 characters ("cc") are recorded in restBytes of the state.
+	assert.Equal(t, 2, len(state.RestBytes), "number of restBytes is incorrect")
+	assert.Equal(t, defaultBlockSize, state.BlockSize, "blockSize is incorrect")
 	// The score for pod1 should be 1.0 because both prompt blocks ("aaaa" and "aabb") were found in its cache.
 	assert.Equal(t, float64(1), scores[pod1], "score for pod1 should be a perfect match")
 	assert.Equal(t, float64(0), scores[pod2], "score for pod2 should be 0")
@@ -362,6 +377,19 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
 	plugin.PreRequest(context.Background(), req1, schedulingResult, 0)
 	plugin.wg.Wait()
 
+	resp1 := &types.LLMResponse{
+		ChatCompletion: &types.ChatCompletionResponse{
+			Choices: []types.ChatChoice{
+				{
+					Message: types.Message{Role: "assistant", Content: "I'm doing well, thank you! How can I help you today?"},
+				},
+			},
+		},
+	}
+	// Simulate resp1 being recorded in the KV cache.
+	plugin.ResponseComplete(context.Background(), req1, resp1, pod1.GetPod())
+	plugin.wg.Wait()
+
 	// Second request adds assistant response and new user message (conversation grows)
 	req2 := &types.LLMRequest{
 		RequestId:   uuid.NewString(),
@@ -389,13 +417,27 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
 	cachedBlocks := state.PrefixCacheServers[ServerID(pod1.GetPod().NamespacedName)]
 	expectedScore := float64(cachedBlocks) / float64(extendedHashCount)
 	assert.Equal(t, expectedScore, scores[pod1], "pod1 should have prefix cache hit")
+	assert.Greater(t, scores[pod1], float64(0.5), "given the response is also prefix cached, the cache hit rate should be well above 0.5")
 	assert.Equal(t, float64(0), scores[pod2], "pod2 should have no cache hit")
 
 	// Simulate pod1 was picked again
 	plugin.PreRequest(context.Background(), req2, schedulingResult, 0)
 	plugin.wg.Wait()
 
-	// Third request continues the conversation even further
+	resp2 := &types.LLMResponse{
+		ChatCompletion: &types.ChatCompletionResponse{
+			Choices: []types.ChatChoice{
+				{
+					Message: types.Message{Role: "assistant", Content: "Prefix caching is a technique where..."},
+				},
+			},
+		},
+	}
+	// Simulate resp2 being recorded in the KV cache.
+	plugin.ResponseComplete(context.Background(), req2, resp2, pod1.GetPod())
+	plugin.wg.Wait()
+
+	// Third request replays the whole conversation above, so the cache hit rate should reach 1.0.
 	req3 := &types.LLMRequest{
 		RequestId:   uuid.NewString(),
 		TargetModel: "test-model1",
@@ -407,7 +449,6 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
 				{Role: "assistant", Content: "I'm doing well, thank you! How can I help you today?"},
 				{Role: "user", Content: "Can you explain how prefix caching works?"},
 				{Role: "assistant", Content: "Prefix caching is a technique where..."},
-				{Role: "user", Content: "That's very helpful, thank you!"},
 			},
 		},
 	}
@@ -424,7 +465,7 @@ func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
 	cachedBlocks = state.PrefixCacheServers[ServerID(pod1.GetPod().NamespacedName)]
 	expectedScore = float64(cachedBlocks) / float64(longHashCount)
 	assert.Equal(t, expectedScore, scores[pod1], "pod1 should have higher prefix cache hit")
-	assert.Greater(t, scores[pod1], float64(0.5), "cache hit rate should be substantial for growing conversation")
+	assert.Equal(t, scores[pod1], float64(1), "cache hit rate should be 1.0 when the full conversation is cached")
 	assert.Equal(t, float64(0), scores[pod2], "pod2 should still have no cache hit")
 }
 
diff --git a/pkg/epp/scheduling/types/llmresponse.go b/pkg/epp/scheduling/types/llmresponse.go
index 1061f3ccc..f6985847a 100644
--- a/pkg/epp/scheduling/types/llmresponse.go
+++ b/pkg/epp/scheduling/types/llmresponse.go
@@ -18,6 +18,7 @@ package types
 
 import (
 	"encoding/json"
+	"errors"
 	"fmt"
 )
 
@@ -26,19 +27,19 @@ import (
 type LLMResponse struct {
 	// ChatCompletion is the representation of the OpenAI /v1/chat/completions response body.
 	ChatCompletion *ChatCompletionResponse `json:"chat_completion,omitempty"`
-	// LegacyCompletion is the representation of the OpenAI /v1/completions response body.
-	LegacyCompletion *LegacyCompletionResponse `json:"legacy_completion,omitempty"`
+	// Completion is the representation of the OpenAI /v1/completions response body.
+	Completion *CompletionResponse `json:"legacy_completion,omitempty"`
 }
 
-// GetFirstChoiceContent extracts the primary text content from the first choice
-// in either a ChatCompletion or a LegacyCompletion response.
-func (res *LLMResponse) GetFirstChoiceContent() string {
+// FirstChoiceContent returns the content of the first choice in the response.
+func (res *LLMResponse) FirstChoiceContent() ([]byte, error) {
 	if res.ChatCompletion != nil && len(res.ChatCompletion.Choices) > 0 {
-		return res.ChatCompletion.Choices[0].Message.Content
-	} else if res.LegacyCompletion != nil && len(res.LegacyCompletion.Choices) > 0 {
-		return res.LegacyCompletion.Choices[0].Text
+		return MarshalMessagesToJSON(res.ChatCompletion.Choices[0].Message)
 	}
-	return ""
+	if res.Completion != nil && len(res.Completion.Choices) > 0 {
+		return []byte(res.Completion.Choices[0].Text), nil
+	}
+	return nil, errors.New("no choices found in the LLM response")
 }
 
 // ChatCompletionResponse represents the full response body for the chat completions API.
@@ -60,8 +61,8 @@ func (r *ChatCompletionResponse) String() string {
 
 // ChatChoice represents a single choice in the chat completion response.
 type ChatChoice struct {
-	Message      ChatMessage `json:"message"`
-	FinishReason string      `json:"finish_reason"`
+	Message      Message `json:"message"`
+	FinishReason string  `json:"finish_reason"`
 }
 
 // ChatMessage represents the message object within a choice.
@@ -70,13 +71,13 @@ type ChatMessage struct {
 	Content string `json:"content"`
 }
 
-// LegacyCompletionResponse represents the full response body for the legacy completions API.
-type LegacyCompletionResponse struct {
-	Choices []LegacyChoice `json:"choices"`
-	Usage   *Usage         `json:"usage,omitempty"`
+// CompletionResponse represents the full response body for the legacy completions API.
+type CompletionResponse struct {
+	Choices []CompletionChoice `json:"choices"`
+	Usage   *Usage             `json:"usage,omitempty"`
 }
 
-func (r *LegacyCompletionResponse) String() string {
+func (r *CompletionResponse) String() string {
 	if r == nil {
 		return nilString
 	}
@@ -87,8 +88,8 @@ func (r *LegacyCompletionResponse) String() string {
 	return fmt.Sprintf("{TextLength: %d, Usage: %v}", textLen, r.Usage)
 }
 
-// LegacyChoice represents a single choice in the legacy completion response.
-type LegacyChoice struct {
+// CompletionChoice represents a single choice in the legacy completion response.
+type CompletionChoice struct {
 	Text         string `json:"text"`
 	FinishReason string `json:"finish_reason"`
 }
@@ -111,7 +112,7 @@ func (u *Usage) String() string {
 // as a chat completion and then as a legacy completion response.
 func NewLLMResponseFromBytes(body []byte) (*LLMResponse, error) {
 	if len(body) == 0 {
-		return nil, fmt.Errorf("input bytes are empty")
+		return nil, errors.New("input bytes are empty")
 	}
 
 	// Attempt to unmarshal as a ChatCompletionResponse first.
@@ -124,12 +125,12 @@ func NewLLMResponseFromBytes(body []byte) (*LLMResponse, error) {
 	}
 
 	// Try to unmarshal as a LegacyCompletionResponse.
-	var legacyResp LegacyCompletionResponse
+	var legacyResp CompletionResponse
 	if err := json.Unmarshal(body, &legacyResp); err == nil {
 		if len(legacyResp.Choices) > 0 {
-			return &LLMResponse{LegacyCompletion: &legacyResp}, nil
+			return &LLMResponse{Completion: &legacyResp}, nil
 		}
 	}
 
-	return nil, fmt.Errorf("failed to unmarshal body into any known LLM response format")
+	return nil, errors.New("failed to unmarshal body into any known LLM response format")
 }
diff --git a/pkg/epp/scheduling/types/llmresponse_test.go b/pkg/epp/scheduling/types/llmresponse_test.go
index 8904062a3..759c9caac 100644
--- a/pkg/epp/scheduling/types/llmresponse_test.go
+++ b/pkg/epp/scheduling/types/llmresponse_test.go
@@ -109,7 +109,7 @@ func TestNewLLMResponseFromBytes(t *testing.T) {
 			ChatCompletion: &ChatCompletionResponse{
 				Choices: []ChatChoice{
 					{
-						Message: ChatMessage{
+						Message: Message{
 							Role:    "assistant",
 							Content: "Hello!",
 						},
@@ -129,8 +129,8 @@ func TestNewLLMResponseFromBytes(t *testing.T) {
 			name:  "valid legacy completion response",
 			input: []byte(legacyCompletionJSON),
 			want: &LLMResponse{
-				LegacyCompletion: &LegacyCompletionResponse{
-					Choices: []LegacyChoice{
+				Completion: &CompletionResponse{
+					Choices: []CompletionChoice{
 						{
 							Text:         "Hello there!",
 							FinishReason: "stop",
@@ -182,7 +182,7 @@ func TestNewLLMResponseFromBytes(t *testing.T) {
 				ChatCompletion: &ChatCompletionResponse{
 					Choices: []ChatChoice{
 						{
-							Message: ChatMessage{
+							Message: Message{
 								Role:    "assistant",
 								Content: "Hello!",
 							},
@@ -197,8 +197,8 @@ func TestNewLLMResponseFromBytes(t *testing.T) {
 			name:  "legacy completion with empty usage",
 			input: []byte(legacyCompletionEmptyUsageJSON),
 			want: &LLMResponse{
-				LegacyCompletion: &LegacyCompletionResponse{
-					Choices: []LegacyChoice{
+				Completion: &CompletionResponse{
+					Choices: []CompletionChoice{
 						{
 							Text:         "Hello there!",
 							FinishReason: "stop",
@@ -226,33 +226,34 @@ func TestNewLLMResponseFromBytes(t *testing.T) {
 	}
 }
 
-func TestGetFirstChoiceContent(t *testing.T) {
+func TestFirstChoiceContent(t *testing.T) {
 	testCases := []struct {
-		name string
-		res  *LLMResponse
-		want string
+		name      string
+		res       *LLMResponse
+		want      []byte
+		wantError bool
 	}{
 		{
 			name: "chatCompletion with choice",
 			res: &LLMResponse{
 				ChatCompletion: &ChatCompletionResponse{
 					Choices: []ChatChoice{
-						{Message: ChatMessage{Content: "Hello from Chat"}},
+						{Message: Message{Role: "assistant", Content: "Hello from Chat"}},
 					},
 				},
 			},
-			want: "Hello from Chat",
+			want: []byte(`{"Role":"assistant","Content":"Hello from Chat"},`),
 		},
 		{
 			name: "legacyCompletion with choice",
 			res: &LLMResponse{
-				LegacyCompletion: &LegacyCompletionResponse{
-					Choices: []LegacyChoice{
+				Completion: &CompletionResponse{
+					Choices: []CompletionChoice{
 						{Text: "Hello from Legacy"},
 					},
 				},
 			},
-			want: "Hello from Legacy",
+			want: []byte(`Hello from Legacy`),
 		},
 		{
 			name: "chatCompletion with no choices",
 			res: &LLMResponse{
 				ChatCompletion: &ChatCompletionResponse{
@@ -261,37 +262,40 @@ func TestGetFirstChoiceContent(t *testing.T) {
 					Choices: []ChatChoice{},
 				},
 			},
-			want: "",
+			wantError: true,
 		},
 		{
 			name: "legacyCompletion with no choices",
 			res: &LLMResponse{
-				LegacyCompletion: &LegacyCompletionResponse{
-					Choices: []LegacyChoice{},
+				Completion: &CompletionResponse{
+					Choices: []CompletionChoice{},
 				},
 			},
-			want: "",
+			wantError: true,
 		},
 		{
 			name: "LLMResponse with all fields nil",
 			res: &LLMResponse{
-				ChatCompletion:   nil,
-				LegacyCompletion: nil,
+				ChatCompletion: nil,
+				Completion:     nil,
 			},
-			want: "",
+			wantError: true,
 		},
 		{
-			name: "Empty LLMResponse struct",
-			res:  &LLMResponse{},
-			want: "",
+			name:      "Empty LLMResponse struct",
+			res:       &LLMResponse{},
+			wantError: true,
 		},
 	}
 
 	for _, tc := range testCases {
 		t.Run(tc.name, func(t *testing.T) {
-			got := tc.res.GetFirstChoiceContent()
+			got, err := tc.res.FirstChoiceContent()
+			if tc.wantError != (err != nil) {
+				t.Errorf("FirstChoiceContent() error = %v, wantError %v", err, tc.wantError)
+			}
 			if diff := cmp.Diff(tc.want, got); diff != "" {
-				t.Errorf("GetFirstChoiceContent() mismatch (-want +got):\n%s", diff)
+				t.Errorf("FirstChoiceContent() mismatch (-want +got):\n%s", diff)
 			}
 		})
 	}
@@ -345,7 +349,7 @@ func TestChatCompletionResponse_String(t *testing.T) {
 			name: "response with choices",
 			r: &ChatCompletionResponse{
 				Choices: []ChatChoice{
-					{Message: ChatMessage{Content: "hello"}},
+					{Message: Message{Content: "hello"}},
 				},
 				Usage: &Usage{PromptTokens: 1, CompletionTokens: 2, TotalTokens: 3},
@@ -362,10 +366,10 @@ func TestChatCompletionResponse_String(t *testing.T) {
 }
 
 func TestLegacyCompletionResponse_String(t *testing.T) {
-	var nilResp *LegacyCompletionResponse
+	var nilResp *CompletionResponse
 	tests := []struct {
 		name string
-		r    *LegacyCompletionResponse
+		r    *CompletionResponse
 		want string
 	}{
 		{
@@ -375,13 +379,13 @@ func TestLegacyCompletionResponse_String(t *testing.T) {
 		},
 		{
 			name: "response with no choices",
-			r:    &LegacyCompletionResponse{Choices: []LegacyChoice{}, Usage: &Usage{}},
+			r:    &CompletionResponse{Choices: []CompletionChoice{}, Usage: &Usage{}},
 			want: "{TextLength: 0, Usage: {Prompt: 0, Completion: 0, Total: 0}}",
 		},
 		{
 			name: "response with choices",
-			r: &LegacyCompletionResponse{
-				Choices: []LegacyChoice{
+			r: &CompletionResponse{
+				Choices: []CompletionChoice{
 					{Text: "hello world"},
 				},
 				Usage: &Usage{PromptTokens: 1, CompletionTokens: 2, TotalTokens: 3},
diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go
index c93f0c5ac..18c2a6831 100644
--- a/pkg/epp/scheduling/types/types.go
+++ b/pkg/epp/scheduling/types/types.go
@@ -17,13 +17,18 @@ limitations under the License.
 package types
 
 import (
+	"bytes"
+	"encoding/json"
 	"fmt"
 
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
 	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
 )
 
-const nilString = ""
+const (
+	nilString    = ""
+	messageSplit = ","
+)
 
 // LLMRequest is a structured representation of the fields we parse out of the LLMRequest body.
type LLMRequest struct { @@ -125,6 +130,25 @@ type Message struct { Content string // TODO: support multi-modal content } +// MarshalMessagesToJSON converts a slice of Message structs into a JSON byte slice. +// This is used to create a consistent byte representation for prefix caching calculations, +// allowing us to identify common prefixes between LLM requests and responses. +func MarshalMessagesToJSON(messages ...Message) ([]byte, error) { + if len(messages) == 0 { + return []byte{}, nil + } + var buf bytes.Buffer + for _, msg := range messages { + jsonBytes, err := json.Marshal(msg) + if err != nil { + return []byte{}, err + } + buf.Write(jsonBytes) + buf.WriteString(messageSplit) + } + return buf.Bytes(), nil +} + type Pod interface { GetPod() *backend.Pod GetMetrics() *backendmetrics.MetricsState diff --git a/pkg/epp/scheduling/types/types_test.go b/pkg/epp/scheduling/types/types_test.go new file mode 100644 index 000000000..711b6dffb --- /dev/null +++ b/pkg/epp/scheduling/types/types_test.go @@ -0,0 +1,69 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package types + +import ( + "bytes" + "testing" +) + +func TestMarshalMessagesToJSON(t *testing.T) { + testCases := []struct { + name string + messages []Message + want []byte + wantErr bool + }{ + { + name: "empty messages", + messages: []Message{}, + want: []byte{}, + wantErr: false, + }, + { + name: "single message", + messages: []Message{ + {Role: "user", Content: "Hello"}, + }, + want: []byte(`{"Role":"user","Content":"Hello"},`), + wantErr: false, + }, + { + name: "multiple messages", + messages: []Message{ + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "Hi there!"}, + }, + want: []byte(`{"Role":"user","Content":"Hello"},{"Role":"assistant","Content":"Hi there!"},`), + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := MarshalMessagesToJSON(tc.messages...) + if (err != nil) != tc.wantErr { + t.Errorf("MarshalMessagesToJSON() error = %v, wantErr %v", err, tc.wantErr) + return + } + + if !bytes.Equal(got, tc.want) { + t.Errorf("MarshalMessagesToJSON() got = %s, want %s", string(got), string(tc.want)) + } + }) + } +}
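
A minimal end-to-end sketch of how the pieces introduced above compose, using only the exported API shown in this patch (NewLLMResponseFromBytes, FirstChoiceContent, MarshalMessagesToJSON); the request body and the output comment are illustrative assumptions, not part of the change:

	package main

	import (
		"fmt"

		"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
	)

	func main() {
		// A raw /v1/chat/completions body, as the model server would return it.
		body := []byte(`{"choices":[{"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}]}`)

		// NewLLMResponseFromBytes tries the chat completions shape first and
		// falls back to the legacy completions shape.
		resp, err := types.NewLLMResponseFromBytes(body)
		if err != nil {
			fmt.Println("unrecognized response format:", err)
			return
		}

		// For a chat completion, FirstChoiceContent returns the first choice's
		// message marshaled via MarshalMessagesToJSON (with its trailing ","
		// separator); for a legacy completion it returns the raw choice text.
		// The prefix plugin appends these bytes after state.RestBytes before
		// hashing response blocks into the KV cache index.
		content, err := resp.FirstChoiceContent()
		if err != nil {
			fmt.Println("no choices in response:", err)
			return
		}
		fmt.Printf("%s\n", content) // {"Role":"assistant","Content":"Hello!"},
	}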