@@ -130,8 +130,10 @@ func (s *SchedulingContextState) Clone() plugins.StateData {
130130}
131131
132132// compile-time type assertion
133- var _ framework.Scorer = & Plugin {}
134- var _ requestcontrol.PreRequest = & Plugin {}
133+ var (
134+ _ framework.Scorer = & Plugin {}
135+ _ requestcontrol.PreRequest = & Plugin {}
136+ )
135137
136138// PrefixCachePluginFactory defines the factory function for Prefix plugin.
137139func PrefixCachePluginFactory (name string , rawParameters json.RawMessage , handle plugins.Handle ) (plugins.Plugin , error ) {
@@ -254,7 +256,6 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
254256 for server := range cachedServers {
255257 // Update servers with their longest prefix match.
256258 res [server ]++
257-
258259 }
259260 }
260261 }
@@ -266,33 +267,39 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
266267// For block i, hash(i) = hash(block i content, hash(i-1)).
267268func hashPrompt (ctx context.Context , request * types.LLMRequest , cacheBlockSize int , maxPrefixBlocks int ) []BlockHash {
268269 loggerDebug := log .FromContext (ctx ).V (logutil .DEBUG )
269- prompt := []byte (request .Prompt )
270- if len (prompt ) < cacheBlockSize {
271- loggerDebug .Info ("Request body too small for prefix cache" , "size" , len (prompt ), "block size" , cacheBlockSize )
270+ if request == nil || request .Body == nil {
271+ loggerDebug .Info ("Request or request data is nil, skipping hashing" )
272272 return nil
273273 }
274- if len (prompt ) > cacheBlockSize * maxPrefixBlocks {
275- loggerDebug .Info ("Truncating input" , "size" , len (prompt ), "max prefix blocks" , maxPrefixBlocks , "block size" , cacheBlockSize )
276- prompt = prompt [:maxPrefixBlocks * cacheBlockSize ]
274+
275+ userInput , err := getUserInputBytes (request )
276+ if err != nil {
277+ loggerDebug .Error (err , "Failed to get user input bytes" )
278+ return nil
277279 }
278- // Split the body into blocks of size cacheBlockSize. The +1 is to account for the model.
279- // If the last block is smaller than cacheBlockSize, it will be ignored.
280- res := make ([]BlockHash , 0 , 1 + len (prompt )/ cacheBlockSize )
281- // Add the model to the first block hash so that different models have different hashes even with the same body.
282280
283- firstBlockSize := cacheBlockSize
284- if len (prompt ) < cacheBlockSize {
285- firstBlockSize = len ( prompt )
281+ if len ( userInput ) < cacheBlockSize {
282+ loggerDebug . Info ( "Request body too small for prefix cache" , "size" , len (userInput ), "block size" , cacheBlockSize )
283+ return nil
286284 }
287- firstBlock := prompt [0 :firstBlockSize ]
288- firstBlockWithModel := append ([]byte (request .TargetModel ), firstBlock ... )
289- res = append (res , BlockHash (xxhash .Sum64 (firstBlockWithModel )))
290-
291- for i := cacheBlockSize ; i + cacheBlockSize <= len (prompt ); i += cacheBlockSize {
292- block := prompt [i : i + cacheBlockSize ]
293- prevBlockHash := res [len (res )- 1 ]
294- block = append (block , toBytes (prevBlockHash )... )
295- res = append (res , BlockHash (xxhash .Sum64 (block )))
285+ if len (userInput ) > cacheBlockSize * maxPrefixBlocks {
286+ loggerDebug .Info ("Truncating input" , "size" , len (userInput ), "max prefix blocks" , maxPrefixBlocks , "block size" , cacheBlockSize )
287+ userInput = userInput [:maxPrefixBlocks * cacheBlockSize ]
288+ }
289+ // Split the body into blocks of size cacheBlockSize.
290+ // If the last block is smaller than cacheBlockSize, it will be ignored.
291+ res := make ([]BlockHash , 0 , len (userInput )/ cacheBlockSize )
292+ // Add the model to the first block hash so that different models have different hashes even with the same body.
293+ h := xxhash .New ()
294+ _ , _ = h .Write ([]byte (request .TargetModel ))
295+ prevBlockHash := BlockHash (h .Sum64 ())
296+ for i := 0 ; i + cacheBlockSize <= len (userInput ); i += cacheBlockSize {
297+ h .Reset ()
298+ _ , _ = h .Write (userInput [i : i + cacheBlockSize ])
299+ _ , _ = h .Write (toBytes (prevBlockHash ))
300+ res = append (res , BlockHash (h .Sum64 ()))
301+
302+ prevBlockHash = res [len (res )- 1 ]
296303 }
297304 return res
298305}
@@ -302,3 +309,12 @@ func toBytes(i BlockHash) []byte {
302309 binary .LittleEndian .PutUint64 (bytes , uint64 (i ))
303310 return bytes
304311}
312+
313+ func getUserInputBytes (request * types.LLMRequest ) ([]byte , error ) {
314+ if request .Body .Completions != nil { // assumed to be valid if not nil
315+ return []byte (request .Body .Completions .Prompt ), nil
316+ }
317+
318+ // must be chat-completions request at this point, return bytes of entire messages
319+ return json .Marshal (request .Body .ChatCompletions .Messages )
320+ }
0 commit comments