Skip to content

Commit 9f37bfa

Browse files
vMaroon authored and BenjaminBraunDev committed
Refactor LLMRequest: Structured RequestData for Completions & Chat-Completions (kubernetes-sigs#1446)
* - added more useful fields to types.LLMRequest:
    1. cleaner API declaration
    2. data fields are preserved; after-read transformations are done in plugins
    3. prefix-cache scorer does not need naive templating
  - minor bugfixes and improvements

  Signed-off-by: Maroon Ayoub <[email protected]>

* removed LLMRequestData::String

  Signed-off-by: Maroon Ayoub <[email protected]>

* - rename LLMRequestData to LLMRequestBody
  - rename LLMRequest.Data to LLMRequest.Body
  - test refactoring after rebase

  Signed-off-by: Maroon Ayoub <[email protected]>

---------

Signed-off-by: Maroon Ayoub <[email protected]>
1 parent fdecb8a commit 9f37bfa

File tree

6 files changed

+590
-194
lines changed

6 files changed

+590
-194
lines changed

pkg/epp/requestcontrol/director.go

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -206,10 +206,11 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
206206
}
207207
reqCtx.Request.Body["model"] = reqCtx.TargetModelName
208208

209-
prompt, err := requtil.ExtractPromptFromRequestBody(requestBodyMap)
209+
requestBody, err := requtil.ExtractRequestBody(reqCtx.Request.Body)
210210
if err != nil {
211-
return reqCtx, err
211+
return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: fmt.Errorf("failed to extract request data: %w", err).Error()}
212212
}
213+
213214
infObjective := d.datastore.ObjectiveGet(reqCtx.ObjectiveKey)
214215
if infObjective == nil {
215216
logger.V(logutil.VERBOSE).Info("No associated InferenceObjective found, using default", "objectiveKey", reqCtx.ObjectiveKey)
@@ -247,7 +248,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
247248
reqCtx.SchedulingRequest = &schedulingtypes.LLMRequest{
248249
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
249250
TargetModel: reqCtx.TargetModelName,
250-
Prompt: prompt,
251+
Body: requestBody,
251252
Headers: reqCtx.Request.Headers,
252253
TTFTSLO: ttftSLO,
253254
AvgTPOTSLO: avgTPOTSLO,

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 41 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -130,8 +130,10 @@ func (s *SchedulingContextState) Clone() plugins.StateData {
130130
}
131131

132132
// compile-time type assertion
133-
var _ framework.Scorer = &Plugin{}
134-
var _ requestcontrol.PreRequest = &Plugin{}
133+
var (
134+
_ framework.Scorer = &Plugin{}
135+
_ requestcontrol.PreRequest = &Plugin{}
136+
)
135137

136138
// PrefixCachePluginFactory defines the factory function for Prefix plugin.
137139
func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
@@ -254,7 +256,6 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
254256
for server := range cachedServers {
255257
// Update servers with their longest prefix match.
256258
res[server]++
257-
258259
}
259260
}
260261
}
@@ -266,33 +267,39 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
266267
// For block i, hash(i) = hash(block i content, hash(i-1)).
267268
func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) []BlockHash {
268269
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
269-
prompt := []byte(request.Prompt)
270-
if len(prompt) < cacheBlockSize {
271-
loggerDebug.Info("Request body too small for prefix cache", "size", len(prompt), "block size", cacheBlockSize)
270+
if request == nil || request.Body == nil {
271+
loggerDebug.Info("Request or request data is nil, skipping hashing")
272272
return nil
273273
}
274-
if len(prompt) > cacheBlockSize*maxPrefixBlocks {
275-
loggerDebug.Info("Truncating input", "size", len(prompt), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
276-
prompt = prompt[:maxPrefixBlocks*cacheBlockSize]
274+
275+
userInput, err := getUserInputBytes(request)
276+
if err != nil {
277+
loggerDebug.Error(err, "Failed to get user input bytes")
278+
return nil
277279
}
278-
// Split the body into blocks of size cacheBlockSize. The +1 is to account for the model.
279-
// If the last block is smaller than cacheBlockSize, it will be ignored.
280-
res := make([]BlockHash, 0, 1+len(prompt)/cacheBlockSize)
281-
// Add the model to the first block hash so that different models have different hashes even with the same body.
282280

283-
firstBlockSize := cacheBlockSize
284-
if len(prompt) < cacheBlockSize {
285-
firstBlockSize = len(prompt)
281+
if len(userInput) < cacheBlockSize {
282+
loggerDebug.Info("Request body too small for prefix cache", "size", len(userInput), "block size", cacheBlockSize)
283+
return nil
286284
}
287-
firstBlock := prompt[0:firstBlockSize]
288-
firstBlockWithModel := append([]byte(request.TargetModel), firstBlock...)
289-
res = append(res, BlockHash(xxhash.Sum64(firstBlockWithModel)))
290-
291-
for i := cacheBlockSize; i+cacheBlockSize <= len(prompt); i += cacheBlockSize {
292-
block := prompt[i : i+cacheBlockSize]
293-
prevBlockHash := res[len(res)-1]
294-
block = append(block, toBytes(prevBlockHash)...)
295-
res = append(res, BlockHash(xxhash.Sum64(block)))
285+
if len(userInput) > cacheBlockSize*maxPrefixBlocks {
286+
loggerDebug.Info("Truncating input", "size", len(userInput), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
287+
userInput = userInput[:maxPrefixBlocks*cacheBlockSize]
288+
}
289+
// Split the body into blocks of size cacheBlockSize.
290+
// If the last block is smaller than cacheBlockSize, it will be ignored.
291+
res := make([]BlockHash, 0, len(userInput)/cacheBlockSize)
292+
// Add the model to the first block hash so that different models have different hashes even with the same body.
293+
h := xxhash.New()
294+
_, _ = h.Write([]byte(request.TargetModel))
295+
prevBlockHash := BlockHash(h.Sum64())
296+
for i := 0; i+cacheBlockSize <= len(userInput); i += cacheBlockSize {
297+
h.Reset()
298+
_, _ = h.Write(userInput[i : i+cacheBlockSize])
299+
_, _ = h.Write(toBytes(prevBlockHash))
300+
res = append(res, BlockHash(h.Sum64()))
301+
302+
prevBlockHash = res[len(res)-1]
296303
}
297304
return res
298305
}
@@ -302,3 +309,12 @@ func toBytes(i BlockHash) []byte {
302309
binary.LittleEndian.PutUint64(bytes, uint64(i))
303310
return bytes
304311
}
312+
313+
func getUserInputBytes(request *types.LLMRequest) ([]byte, error) {
314+
if request.Body.Completions != nil { // assumed to be valid if not nil
315+
return []byte(request.Body.Completions.Prompt), nil
316+
}
317+
318+
// must be chat-completions request at this point, return bytes of entire messages
319+
return json.Marshal(request.Body.ChatCompletions.Messages)
320+
}

0 commit comments

Comments
 (0)