7 changes: 4 additions & 3 deletions pkg/epp/requestcontrol/director.go
@@ -103,10 +103,11 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
 	}
 	reqCtx.Request.Body["model"] = reqCtx.TargetModelName
 
-	prompt, err := requtil.ExtractPromptFromRequestBody(requestBodyMap)
+	requestBody, err := requtil.ExtractRequestBody(reqCtx.Request.Body)
 	if err != nil {
-		return reqCtx, err
+		return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: fmt.Errorf("failed to extract request data: %w", err).Error()}
 	}
 
 	infObjective := d.datastore.ObjectiveGet(reqCtx.ObjectiveKey)
 	if infObjective == nil {
 		logger.V(logutil.VERBOSE).Info("No associated InferenceObjective found, using default", "objectiveKey", reqCtx.ObjectiveKey)

@@ -124,7 +125,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
 	reqCtx.SchedulingRequest = &schedulingtypes.LLMRequest{
 		RequestId:   reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
 		TargetModel: reqCtx.TargetModelName,
-		Prompt:      prompt,
+		Body:        requestBody,
 		Headers:     reqCtx.Request.Headers,
 	}

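Note on the change above: HandleRequest now hands the scheduler the full structured request body instead of a flat prompt string, and an extraction failure surfaces as a BadRequest rather than an opaque error. The concrete shape of the extracted body is not part of this diff; the following is a minimal sketch inferred from how getUserInputBytes consumes it below (type and field names assumed, not confirmed by this PR):

// Sketch only — assumed shape, inferred from getUserInputBytes below.
type LLMRequestBody struct {
	Completions     *CompletionsRequest     // set for /v1/completions requests
	ChatCompletions *ChatCompletionsRequest // set for /v1/chat/completions requests
}

type CompletionsRequest struct {
	Prompt string // hashed directly by the prefix plugin
}

type ChatCompletionsRequest struct {
	Messages []Message // hashed via json.Marshal in the prefix plugin
}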
66 changes: 41 additions & 25 deletions pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
@@ -125,8 +125,10 @@ func (s *SchedulingContextState) Clone() plugins.StateData {
 }
 
 // compile-time type assertion
-var _ framework.Scorer = &Plugin{}
-var _ requestcontrol.PreRequest = &Plugin{}
+var (
+	_ framework.Scorer          = &Plugin{}
+	_ requestcontrol.PreRequest = &Plugin{}
+)
 
 // PrefixCachePluginFactory defines the factory function for Prefix plugin.
 func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
@@ -248,7 +250,6 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
 		for server := range cachedServers {
 			// Update servers with their longest prefix match.
 			res[server]++
-
 		}
 	}
 }
@@ -260,33 +261,39 @@
 // For block i, hash(i) = hash(block i content, hash(i-1)).
 func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) []BlockHash {
 	loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
-	prompt := []byte(request.Prompt)
-	if len(prompt) < cacheBlockSize {
-		loggerDebug.Info("Request body too small for prefix cache", "size", len(prompt), "block size", cacheBlockSize)
+	if request == nil || request.Body == nil {
+		loggerDebug.Info("Request or request data is nil, skipping hashing")
 		return nil
 	}
-	if len(prompt) > cacheBlockSize*maxPrefixBlocks {
-		loggerDebug.Info("Truncating input", "size", len(prompt), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
-		prompt = prompt[:maxPrefixBlocks*cacheBlockSize]
+
+	userInput, err := getUserInputBytes(request)
+	if err != nil {
+		loggerDebug.Error(err, "Failed to get user input bytes")
+		return nil
 	}
-	// Split the body into blocks of size cacheBlockSize. The +1 is to account for the model.
-	// If the last block is smaller than cacheBlockSize, it will be ignored.
-	res := make([]BlockHash, 0, 1+len(prompt)/cacheBlockSize)
-	// Add the model to the first block hash so that different models have different hashes even with the same body.
-
-	firstBlockSize := cacheBlockSize
-	if len(prompt) < cacheBlockSize {
-		firstBlockSize = len(prompt)
+	if len(userInput) < cacheBlockSize {
+		loggerDebug.Info("Request body too small for prefix cache", "size", len(userInput), "block size", cacheBlockSize)
+		return nil
 	}
-	firstBlock := prompt[0:firstBlockSize]
-	firstBlockWithModel := append([]byte(request.TargetModel), firstBlock...)
-	res = append(res, BlockHash(xxhash.Sum64(firstBlockWithModel)))
-
-	for i := cacheBlockSize; i+cacheBlockSize <= len(prompt); i += cacheBlockSize {
-		block := prompt[i : i+cacheBlockSize]
-		prevBlockHash := res[len(res)-1]
-		block = append(block, toBytes(prevBlockHash)...)
-		res = append(res, BlockHash(xxhash.Sum64(block)))
+	if len(userInput) > cacheBlockSize*maxPrefixBlocks {
+		loggerDebug.Info("Truncating input", "size", len(userInput), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
+		userInput = userInput[:maxPrefixBlocks*cacheBlockSize]
 	}
+	// Split the body into blocks of size cacheBlockSize.
+	// If the last block is smaller than cacheBlockSize, it will be ignored.
+	res := make([]BlockHash, 0, len(userInput)/cacheBlockSize)
+	// Add the model to the first block hash so that different models have different hashes even with the same body.
+	h := xxhash.New()
+	_, _ = h.Write([]byte(request.TargetModel))
+	prevBlockHash := BlockHash(h.Sum64())
+	for i := 0; i+cacheBlockSize <= len(userInput); i += cacheBlockSize {
+		h.Reset()
+		_, _ = h.Write(userInput[i : i+cacheBlockSize])
+		_, _ = h.Write(toBytes(prevBlockHash))
+		res = append(res, BlockHash(h.Sum64()))
+
+		prevBlockHash = res[len(res)-1]
 	}
 	return res
 }
@@ -296,3 +303,12 @@ func toBytes(i BlockHash) []byte {
 	binary.LittleEndian.PutUint64(bytes, uint64(i))
 	return bytes
 }
+
+func getUserInputBytes(request *types.LLMRequest) ([]byte, error) {
+	if request.Body.Completions != nil { // assumed to be valid if not nil
+		return []byte(request.Body.Completions.Prompt), nil
+	}
+
+	// must be chat-completions request at this point, return bytes of entire messages
+	return json.Marshal(request.Body.ChatCompletions.Messages)
+}
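Note on the hashing rewrite above: instead of materializing each block with the previous hash appended (two slice allocations per block in the old append-based code), the new loop reuses a single xxhash.Digest and seeds the chain with the model name, so hash(i) = xxhash(block_i || hash(i-1)) and identical prompts for different models never collide. For chat-completions requests, the hashed bytes are the JSON-marshaled messages rather than a raw prompt. Below is a minimal, self-contained sketch of the same chaining scheme; the block size and model name are illustrative only, not part of this PR:

package main

import (
	"encoding/binary"
	"fmt"

	"github.com/cespare/xxhash/v2"
)

// chainHashes mirrors the chained block hashing in hashPrompt: the chain is
// seeded with the model name, and each block hash folds in the previous hash,
// so a cache hit on block i implies all earlier blocks (and the model) matched.
func chainHashes(model string, input []byte, blockSize int) []uint64 {
	h := xxhash.New()
	_, _ = h.Write([]byte(model))
	prev := h.Sum64()

	res := make([]uint64, 0, len(input)/blockSize)
	for i := 0; i+blockSize <= len(input); i += blockSize {
		h.Reset()
		_, _ = h.Write(input[i : i+blockSize])
		var b [8]byte
		binary.LittleEndian.PutUint64(b[:], prev)
		_, _ = h.Write(b[:])
		prev = h.Sum64()
		res = append(res, prev)
	}
	return res
}

func main() {
	// 19 input bytes with blockSize 4 yield 4 hashes; the trailing 3 bytes are
	// ignored, matching the "last partial block is ignored" rule above.
	fmt.Println(chainHashes("llama-3", []byte("the quick brown fox"), 4))
}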