@@ -123,8 +123,10 @@ func (s *SchedulingContextState) Clone() plugins.StateData {
123123}
124124
125125// compile-time type assertion
126- var _ framework.Scorer = & Plugin {}
127- var _ requestcontrol.PreRequest = & Plugin {}
126+ var (
127+ _ framework.Scorer = & Plugin {}
128+ _ requestcontrol.PreRequest = & Plugin {}
129+ )
128130
129131// PrefixCachePluginFactory defines the factory function for Prefix plugin.
130132func PrefixCachePluginFactory (name string , rawParameters json.RawMessage , handle plugins.Handle ) (plugins.Plugin , error ) {
@@ -238,7 +240,6 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
238240 for server := range cachedServers {
239241 // Update servers with their longest prefix match.
240242 res [server ]++
241-
242243 }
243244 }
244245 }
@@ -250,33 +251,39 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
250251// For block i, hash(i) = hash(block i content, hash(i-1)).
251252func hashPrompt (ctx context.Context , request * types.LLMRequest , cacheBlockSize int , maxPrefixBlocks int ) []BlockHash {
252253 loggerDebug := log .FromContext (ctx ).V (logutil .DEBUG )
253- prompt := []byte (request .Prompt )
254- if len (prompt ) < cacheBlockSize {
255- loggerDebug .Info ("Request body too small for prefix cache" , "size" , len (prompt ), "block size" , cacheBlockSize )
254+ if request == nil || request .Data == nil {
255+ loggerDebug .Info ("Request or request data is nil, skipping hashing" )
256256 return nil
257257 }
258- if len (prompt ) > cacheBlockSize * maxPrefixBlocks {
259- loggerDebug .Info ("Truncating input" , "size" , len (prompt ), "max prefix blocks" , maxPrefixBlocks , "block size" , cacheBlockSize )
260- prompt = prompt [:maxPrefixBlocks * cacheBlockSize ]
258+
259+ userInput , err := getUserInputBytes (request )
260+ if err != nil {
261+ loggerDebug .Error (err , "Failed to get user input bytes" )
262+ return nil
261263 }
262- // Split the body into blocks of size cacheBlockSize. The +1 is to account for the model.
263- // If the last block is smaller than cacheBlockSize, it will be ignored.
264- res := make ([]BlockHash , 0 , 1 + len (prompt )/ cacheBlockSize )
265- // Add the model to the first block hash so that different models have different hashes even with the same body.
266264
267- firstBlockSize := cacheBlockSize
268- if len (prompt ) < cacheBlockSize {
269- firstBlockSize = len ( prompt )
265+ if len ( userInput ) < cacheBlockSize {
266+ loggerDebug . Info ( "Request body too small for prefix cache" , "size" , len (userInput ), "block size" , cacheBlockSize )
267+ return nil
270268 }
271- firstBlock := prompt [0 :firstBlockSize ]
272- firstBlockWithModel := append ([]byte (request .TargetModel ), firstBlock ... )
273- res = append (res , BlockHash (xxhash .Sum64 (firstBlockWithModel )))
274-
275- for i := cacheBlockSize ; i + cacheBlockSize <= len (prompt ); i += cacheBlockSize {
276- block := prompt [i : i + cacheBlockSize ]
277- prevBlockHash := res [len (res )- 1 ]
278- block = append (block , toBytes (prevBlockHash )... )
279- res = append (res , BlockHash (xxhash .Sum64 (block )))
269+ if len (userInput ) > cacheBlockSize * maxPrefixBlocks {
270+ loggerDebug .Info ("Truncating input" , "size" , len (userInput ), "max prefix blocks" , maxPrefixBlocks , "block size" , cacheBlockSize )
271+ userInput = userInput [:maxPrefixBlocks * cacheBlockSize ]
272+ }
273+ // Split the body into blocks of size cacheBlockSize.
274+ // If the last block is smaller than cacheBlockSize, it will be ignored.
275+ res := make ([]BlockHash , 0 , len (userInput )/ cacheBlockSize )
276+ // Add the model to the first block hash so that different models have different hashes even with the same body.
277+ h := xxhash .New ()
278+ _ , _ = h .Write ([]byte (request .TargetModel ))
279+ prevBlockHash := BlockHash (h .Sum64 ())
280+ for i := 0 ; i + cacheBlockSize <= len (userInput ); i += cacheBlockSize {
281+ h .Reset ()
282+ _ , _ = h .Write (userInput [i : i + cacheBlockSize ])
283+ _ , _ = h .Write (toBytes (prevBlockHash ))
284+ res = append (res , BlockHash (h .Sum64 ()))
285+
286+ prevBlockHash = res [len (res )- 1 ]
280287 }
281288 return res
282289}
@@ -286,3 +293,12 @@ func toBytes(i BlockHash) []byte {
286293 binary .LittleEndian .PutUint64 (bytes , uint64 (i ))
287294 return bytes
288295}
296+
297+ func getUserInputBytes (request * types.LLMRequest ) ([]byte , error ) {
298+ if request .Data .Completions != nil { // assumed to be valid if not nil
299+ return []byte (request .Data .Completions .Prompt ), nil
300+ }
301+
302+ // must be chat-completions request at this point, return bytes of entire messages
303+ return json .Marshal (request .Data .ChatCompletions .Messages )
304+ }
0 commit comments