@@ -125,8 +125,10 @@ func (s *SchedulingContextState) Clone() plugins.StateData {
125125}
126126
127127// compile-time type assertion
128- var _ framework.Scorer = & Plugin {}
129- var _ requestcontrol.PreRequest = & Plugin {}
128+ var (
129+ _ framework.Scorer = & Plugin {}
130+ _ requestcontrol.PreRequest = & Plugin {}
131+ )
130132
131133// PrefixCachePluginFactory defines the factory function for Prefix plugin.
132134func PrefixCachePluginFactory (name string , rawParameters json.RawMessage , handle plugins.Handle ) (plugins.Plugin , error ) {
@@ -248,7 +250,6 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
248250 for server := range cachedServers {
249251 // Update servers with their longest prefix match.
250252 res [server ]++
251-
252253 }
253254 }
254255 }
@@ -260,33 +261,39 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
260261// For block i, hash(i) = hash(block i content, hash(i-1)).
261262func hashPrompt (ctx context.Context , request * types.LLMRequest , cacheBlockSize int , maxPrefixBlocks int ) []BlockHash {
262263 loggerDebug := log .FromContext (ctx ).V (logutil .DEBUG )
263- prompt := []byte (request .Prompt )
264- if len (prompt ) < cacheBlockSize {
265- loggerDebug .Info ("Request body too small for prefix cache" , "size" , len (prompt ), "block size" , cacheBlockSize )
264+ if request == nil || request .Data == nil {
265+ loggerDebug .Info ("Request or request data is nil, skipping hashing" )
266266 return nil
267267 }
268- if len (prompt ) > cacheBlockSize * maxPrefixBlocks {
269- loggerDebug .Info ("Truncating input" , "size" , len (prompt ), "max prefix blocks" , maxPrefixBlocks , "block size" , cacheBlockSize )
270- prompt = prompt [:maxPrefixBlocks * cacheBlockSize ]
268+
269+ userInput , err := getUserInputBytes (request )
270+ if err != nil {
271+ loggerDebug .Error (err , "Failed to get user input bytes" )
272+ return nil
271273 }
272- // Split the body into blocks of size cacheBlockSize. The +1 is to account for the model.
273- // If the last block is smaller than cacheBlockSize, it will be ignored.
274- res := make ([]BlockHash , 0 , 1 + len (prompt )/ cacheBlockSize )
275- // Add the model to the first block hash so that different models have different hashes even with the same body.
276274
277- firstBlockSize := cacheBlockSize
278- if len (prompt ) < cacheBlockSize {
279- firstBlockSize = len ( prompt )
275+ if len ( userInput ) < cacheBlockSize {
276+ loggerDebug . Info ( "Request body too small for prefix cache" , "size" , len (userInput ), "block size" , cacheBlockSize )
277+ return nil
280278 }
281- firstBlock := prompt [0 :firstBlockSize ]
282- firstBlockWithModel := append ([]byte (request .TargetModel ), firstBlock ... )
283- res = append (res , BlockHash (xxhash .Sum64 (firstBlockWithModel )))
284-
285- for i := cacheBlockSize ; i + cacheBlockSize <= len (prompt ); i += cacheBlockSize {
286- block := prompt [i : i + cacheBlockSize ]
287- prevBlockHash := res [len (res )- 1 ]
288- block = append (block , toBytes (prevBlockHash )... )
289- res = append (res , BlockHash (xxhash .Sum64 (block )))
279+ if len (userInput ) > cacheBlockSize * maxPrefixBlocks {
280+ loggerDebug .Info ("Truncating input" , "size" , len (userInput ), "max prefix blocks" , maxPrefixBlocks , "block size" , cacheBlockSize )
281+ userInput = userInput [:maxPrefixBlocks * cacheBlockSize ]
282+ }
283+ // Split the body into blocks of size cacheBlockSize.
284+ // If the last block is smaller than cacheBlockSize, it will be ignored.
285+ res := make ([]BlockHash , 0 , len (userInput )/ cacheBlockSize )
286+ // Add the model to the first block hash so that different models have different hashes even with the same body.
287+ h := xxhash .New ()
288+ _ , _ = h .Write ([]byte (request .TargetModel ))
289+ prevBlockHash := BlockHash (h .Sum64 ())
290+ for i := 0 ; i + cacheBlockSize <= len (userInput ); i += cacheBlockSize {
291+ h .Reset ()
292+ _ , _ = h .Write (userInput [i : i + cacheBlockSize ])
293+ _ , _ = h .Write (toBytes (prevBlockHash ))
294+ res = append (res , BlockHash (h .Sum64 ()))
295+
296+ prevBlockHash = res [len (res )- 1 ]
290297 }
291298 return res
292299}
@@ -296,3 +303,12 @@ func toBytes(i BlockHash) []byte {
296303 binary .LittleEndian .PutUint64 (bytes , uint64 (i ))
297304 return bytes
298305}
306+
307+ func getUserInputBytes (request * types.LLMRequest ) ([]byte , error ) {
308+ if request .Data .Completions != nil { // assumed to be valid if not nil
309+ return []byte (request .Data .Completions .Prompt ), nil
310+ }
311+
312+ // must be chat-completions request at this point, return bytes of entire messages
313+ return json .Marshal (request .Data .ChatCompletions .Messages )
314+ }
0 commit comments