Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pkg/epp/handlers/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ type Request struct {
}
type Response struct {
Headers map[string]string
Body []byte
}
type StreamRequestState int

Expand Down Expand Up @@ -302,6 +303,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
break
}

reqCtx.Response.Body = body
reqCtx, responseErr = s.HandleResponseBody(ctx, reqCtx, responseBody)
if responseErr != nil {
if logger.V(logutil.DEBUG).Enabled() {
Expand Down
15 changes: 8 additions & 7 deletions pkg/epp/requestcontrol/director.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,14 +289,15 @@ func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *hand

// HandleResponseBodyComplete is called when the response body is fully received.
func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
logger := log.FromContext(ctx).WithValues("stage", "bodyChunk")
requestID := reqCtx.Request.Headers[requtil.RequestIdHeaderKey]
logger := log.FromContext(ctx).WithValues("stage", "bodyChunk", requtil.RequestIdHeaderKey, requestID)
logger.V(logutil.DEBUG).Info("Entering HandleResponseBodyComplete")
response := &Response{
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
Headers: reqCtx.Response.Headers,
llmResponse, err := schedulingtypes.NewLLMResponseFromBytes(reqCtx.Response.Body)
if err != nil {
logger.Error(err, "HandleResponseBodyComplete: failed to convert the response to LLMResponse.")
return reqCtx, err
}

d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, llmResponse, reqCtx.TargetPod)

logger.V(logutil.DEBUG).Info("Exiting HandleResponseBodyComplete")
return reqCtx, nil
Expand Down Expand Up @@ -346,7 +347,7 @@ func (d *Director) runResponseStreamingPlugins(ctx context.Context, request *sch
}
}

func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *schedulingtypes.LLMResponse, targetPod *backend.Pod) {
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
for _, plugin := range d.requestControlPlugins.responseCompletePlugins {
loggerDebug.Info("Running ResponseComplete plugin", "plugin", plugin.TypedName())
Expand Down
37 changes: 28 additions & 9 deletions pkg/epp/requestcontrol/director_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,27 @@ func TestDirector_HandleResponseComplete(t *testing.T) {
mockSched := &mockScheduler{}
director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithResponseCompletePlugins(pc1))

chatCompletionJSON := `{
"choices": [
{
"message": {
"role": "assistant",
"content": "Hello!"
},
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 1,
"completion_tokens": 2,
"total_tokens": 3
}
}`
wantLLMResponse, err := schedulingtypes.NewLLMResponseFromBytes([]byte(chatCompletionJSON))
if err != nil {
t.Fatalf("NewLLMResponseFromBytes failed with error: %v", err)
}

reqCtx := &handlers.RequestContext{
Request: &handlers.Request{
Headers: map[string]string{
Expand All @@ -669,24 +690,22 @@ func TestDirector_HandleResponseComplete(t *testing.T) {
},
Response: &handlers.Response{
Headers: map[string]string{"X-Test-Complete-Header": "CompleteValue"},
Body: []byte(chatCompletionJSON),
},
TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}},
}

_, err := director.HandleResponseBodyComplete(ctx, reqCtx)
_, err = director.HandleResponseBodyComplete(ctx, reqCtx)
if err != nil {
t.Fatalf("HandleResponseBodyComplete() returned unexpected error: %v", err)
}

if diff := cmp.Diff("test-req-id-for-complete", pc1.lastRespOnComplete.RequestId); diff != "" {
t.Errorf("Scheduler.OnComplete RequestId mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(reqCtx.Response.Headers, pc1.lastRespOnComplete.Headers); diff != "" {
t.Errorf("Scheduler.OnComplete Headers mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff("namespace1/test-pod-name", pc1.lastTargetPodOnComplete); diff != "" {
t.Errorf("Scheduler.OnComplete TargetPodName mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(wantLLMResponse, pc1.lastRespOnComplete); diff != "" {
t.Errorf("Scheduler.OnComplete response body mismatch (-want +got):\n%s", diff)
}
}

const (
Expand All @@ -709,7 +728,7 @@ type testResponseStreaming struct {

type testResponseComplete struct {
tn plugins.TypedName
lastRespOnComplete *Response
lastRespOnComplete *schedulingtypes.LLMResponse
lastTargetPodOnComplete string
}

Expand Down Expand Up @@ -753,7 +772,7 @@ func (p *testResponseStreaming) ResponseStreaming(_ context.Context, _ *scheduli
p.lastTargetPodOnStreaming = targetPod.NamespacedName.String()
}

func (p *testResponseComplete) ResponseComplete(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
func (p *testResponseComplete) ResponseComplete(_ context.Context, _ *schedulingtypes.LLMRequest, response *schedulingtypes.LLMResponse, targetPod *backend.Pod) {
p.lastRespOnComplete = response
p.lastTargetPodOnComplete = targetPod.NamespacedName.String()
}
2 changes: 1 addition & 1 deletion pkg/epp/requestcontrol/plugins.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,5 @@ type ResponseStreaming interface {
// ResponseComplete is called by the director after the complete response is sent.
type ResponseComplete interface {
plugins.Plugin
ResponseComplete(ctx context.Context, request *types.LLMRequest, response *Response, targetPod *backend.Pod)
ResponseComplete(ctx context.Context, request *types.LLMRequest, response *types.LLMResponse, targetPod *backend.Pod)
}
102 changes: 77 additions & 25 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package prefix

import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
Expand All @@ -28,6 +29,7 @@ import (
k8stypes "k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/log"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
Expand Down Expand Up @@ -117,6 +119,12 @@ var _ plugins.StateData = &SchedulingContextState{}
type SchedulingContextState struct {
// PrefixHashes is a list of prefix hashes of the request prompt broken into blocks.
PrefixHashes []BlockHash
// RestBytes holds the trailing bytes of the prompt that were left over because they
// could not fill a complete block. If non-empty, they are prepended to the response
// body before it is hashed, so the response continues the same block stream —
// this matters especially in multi-turn scenarios.
RestBytes []byte
// BlockSize is the block size used to calculate the hash of the request/response.
BlockSize int
// A map of server to its longest prefix cache match length.
PrefixCacheServers map[ServerID]int
Expand Down Expand Up @@ -192,10 +200,13 @@ func (p *Plugin) WithName(name string) *Plugin {

// Score returns the scoring result for the given list of pods based on context.
func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
blockSize := getBlockSize(pods, p.config.DefaultBlockSize)
// pre score step, hashing prompt and find longest prefix match.
hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config.DefaultBlockSize), p.config.MaxPrefixBlocksToMatch)
hashes, restBytes := hashPrompt(ctx, request, blockSize, p.config.MaxPrefixBlocksToMatch)
state := &SchedulingContextState{
PrefixHashes: hashes,
RestBytes: restBytes,
BlockSize: blockSize,
PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
}

Expand Down Expand Up @@ -226,7 +237,6 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche
targetPod := primaryProfileResult.TargetPods[0].GetPod() // get the first pod of the primary profile

state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it
if err != nil {
log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId)
return
Expand All @@ -244,9 +254,7 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche

total := len(state.PrefixHashes)
matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)]

blockSize := getBlockSize(primaryProfileResult.TargetPods, p.config.DefaultBlockSize)
metrics.RecordPrefixCacheMatch(matchLen*blockSize, total*blockSize)
metrics.RecordPrefixCacheMatch(matchLen*state.BlockSize, total*state.BlockSize)
}

// matchLongestPrefix returns a map of servers and length of prefix that each server caches.
Expand Down Expand Up @@ -301,47 +309,59 @@ func (m *Plugin) CleanUpInactivePods(ctx context.Context, handle plugins.Handle)
// hashPrompt divides the prompt into blocks and calculate the prefix cache for each block.
// hash[0] is calculated including the model name and cache_salt(if provided), since different models generally don't share prefix cache.
// For block i, hash(i) = hash(block i content, hash(i-1)).
func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) []BlockHash {
// It also returns the trailing bytes that did not fill a complete block.
func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) ([]BlockHash, []byte) {
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
if request == nil || request.Body == nil {
loggerDebug.Info("Request or request data is nil, skipping hashing")
return nil
return nil, nil
}

userInput, err := getUserInputBytes(request)
if err != nil {
loggerDebug.Error(err, "Failed to get user input bytes")
return nil
return nil, nil
}
prevBlockHash := defaultPrevBlock(request)
return hashInputWithPrevBlockHash(ctx, prevBlockHash, 0, userInput, cacheBlockSize, maxPrefixBlocks)
}

if len(userInput) < cacheBlockSize {
loggerDebug.Info("Request body too small for prefix cache", "size", len(userInput), "block size", cacheBlockSize)
return nil
}
if len(userInput) > cacheBlockSize*maxPrefixBlocks {
loggerDebug.Info("Truncating input", "size", len(userInput), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
userInput = userInput[:maxPrefixBlocks*cacheBlockSize]
}
// Split the body into blocks of size cacheBlockSize.
// If the last block is smaller than cacheBlockSize, it will be ignored.
res := make([]BlockHash, 0, len(userInput)/cacheBlockSize)
// Add the model to the first block hash so that different models have different hashes even with the same body.
func defaultPrevBlock(request *types.LLMRequest) BlockHash {
h := xxhash.New()
// Add the model to the first block hash so that different models have different hashes even with the same body.
_, _ = h.Write([]byte(request.TargetModel))
if cacheSalt := request.Body.CacheSalt(); cacheSalt != "" {
_, _ = h.Write([]byte(cacheSalt))
}

prevBlockHash := BlockHash(h.Sum64())
for i := 0; i+cacheBlockSize <= len(userInput); i += cacheBlockSize {
return BlockHash(h.Sum64())
}

func hashInputWithPrevBlockHash(ctx context.Context, prevBlockHash BlockHash, prevBlockLength int, input []byte, cacheBlockSize int, maxPrefixBlocks int) ([]BlockHash, []byte) {
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
if len(input)+prevBlockLength < cacheBlockSize {
loggerDebug.Info("Request body too small for prefix cache", "size", len(input), "block size", cacheBlockSize)
return nil, input
}
if len(input)+prevBlockLength > cacheBlockSize*maxPrefixBlocks {
loggerDebug.Info("Truncating input", "size", len(input), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
input = input[:(maxPrefixBlocks*cacheBlockSize - prevBlockLength)]
}
// Split the body into blocks of size cacheBlockSize.
// If the last block is smaller than cacheBlockSize, it will be ignored.
res := make([]BlockHash, 0, len(input)/cacheBlockSize)
lastOffSet := 0
h := xxhash.New()
for i := 0; i+cacheBlockSize <= len(input); i += cacheBlockSize {
h.Reset()
_, _ = h.Write(userInput[i : i+cacheBlockSize])
_, _ = h.Write(input[i : i+cacheBlockSize])
_, _ = h.Write(toBytes(prevBlockHash))
res = append(res, BlockHash(h.Sum64()))

prevBlockHash = res[len(res)-1]
lastOffSet = i + cacheBlockSize
}
return res
return res, input[lastOffSet:]
}

func toBytes(i BlockHash) []byte {
Expand All @@ -356,7 +376,39 @@ func getUserInputBytes(request *types.LLMRequest) ([]byte, error) {
}

// must be chat-completions request at this point, return bytes of entire messages
return json.Marshal(request.Body.ChatCompletions.Messages)
return types.MarshalMessagesToJSON(request.Body.ChatCompletions.Messages...)
}

// ResponseComplete hashes the completed response body — prefixed by the request's
// leftover bytes that did not fill a full block (state.RestBytes) — and records the
// resulting block hashes against the serving pod, so later turns of a multi-turn
// conversation can prefix-match to the pod that already holds the KV cache.
// It also deletes the per-request plugin state once it is consumed.
func (p *Plugin) ResponseComplete(ctx context.Context, request *types.LLMRequest, response *types.LLMResponse, targetPod *backend.Pod) {
	state, err := plugins.ReadPluginStateKey[*SchedulingContextState](p.pluginState, request.RequestId, plugins.StateKey(p.TypedName().String()))
	if err != nil {
		log.FromContext(ctx).Error(err, "failed to read prefix plugin state", "requestID", request.RequestId)
		return
	}
	p.pluginState.Delete(request.RequestId) // delete the state explicitly after completing using it.

	responseForKVCache, err := response.FirstChoiceContent()
	if err != nil {
		log.FromContext(ctx).Error(err, "failed to get first choice content", "requestID", request.RequestId)
		return
	}
	// Continue the request's block stream: the unhashed tail of the prompt comes first.
	var input bytes.Buffer
	input.Write(state.RestBytes)
	input.Write(responseForKVCache)

	server := ServerID(targetPod.NamespacedName)
	prevBlockHash := defaultPrevBlock(request)
	prevBlockLengthBytes := 0
	if len(state.PrefixHashes) > 0 {
		prevBlockHash = state.PrefixHashes[len(state.PrefixHashes)-1]
		// hashInputWithPrevBlockHash expects the number of BYTES already hashed
		// (it compares this against byte lengths and block-size byte budgets),
		// not the number of blocks — so scale the block count by the block size.
		prevBlockLengthBytes = len(state.PrefixHashes) * state.BlockSize
	}
	hashBlocks, _ := hashInputWithPrevBlockHash(ctx, prevBlockHash, prevBlockLengthBytes, input.Bytes(), state.BlockSize, p.config.MaxPrefixBlocksToMatch)
	p.wg.Add(1)
	go func() {
		p.indexer.Add(hashBlocks, server)
		p.wg.Done()
	}()
}

func getBlockSize(pods []types.Pod, defaultBlockSize int) int {
Expand Down
Loading