@@ -250,33 +250,39 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
250250// For block i, hash(i) = hash(block i content, hash(i-1)).
251251func hashPrompt (ctx context.Context , request * types.LLMRequest , cacheBlockSize int , maxPrefixBlocks int ) []BlockHash {
252252 loggerDebug := log .FromContext (ctx ).V (logutil .DEBUG )
253- prompt := []byte (request .Prompt )
254- if len (prompt ) < cacheBlockSize {
255- loggerDebug .Info ("Request body too small for prefix cache" , "size" , len (prompt ), "block size" , cacheBlockSize )
253+ if request == nil || request .Data == nil {
254+ loggerDebug .Info ("Request or request data is nil, skipping hashing" )
256255 return nil
257256 }
258- if len (prompt ) > cacheBlockSize * maxPrefixBlocks {
259- loggerDebug .Info ("Truncating input" , "size" , len (prompt ), "max prefix blocks" , maxPrefixBlocks , "block size" , cacheBlockSize )
260- prompt = prompt [:maxPrefixBlocks * cacheBlockSize ]
257+
258+ userInput , err := getUserInputBytes (request )
259+ if err != nil {
260+ loggerDebug .Error (err , "Failed to get user input bytes" )
261+ return nil
262+ }
263+
264+ if len (userInput ) < cacheBlockSize {
265+ loggerDebug .Info ("Request body too small for prefix cache" , "size" , len (userInput ), "block size" , cacheBlockSize )
266+ return nil
267+ }
268+ if len (userInput ) > cacheBlockSize * maxPrefixBlocks {
269+ loggerDebug .Info ("Truncating input" , "size" , len (userInput ), "max prefix blocks" , maxPrefixBlocks , "block size" , cacheBlockSize )
270+ userInput = userInput [:maxPrefixBlocks * cacheBlockSize ]
261271 }
262272 // Split the body into blocks of size cacheBlockSize. The +1 is to account for the model.
263273 // If the last block is smaller than cacheBlockSize, it will be ignored.
264- res := make ([]BlockHash , 0 , 1 + len (prompt )/ cacheBlockSize )
274+ res := make ([]BlockHash , 0 , 1 + len (userInput )/ cacheBlockSize )
265275 // Add the model to the first block hash so that different models have different hashes even with the same body.
266-
267- firstBlockSize := cacheBlockSize
268- if len (prompt ) < cacheBlockSize {
269- firstBlockSize = len (prompt )
270- }
271- firstBlock := prompt [0 :firstBlockSize ]
272- firstBlockWithModel := append ([]byte (request .TargetModel ), firstBlock ... )
273- res = append (res , BlockHash (xxhash .Sum64 (firstBlockWithModel )))
274-
275- for i := cacheBlockSize ; i + cacheBlockSize <= len (prompt ); i += cacheBlockSize {
276- block := prompt [i : i + cacheBlockSize ]
277- prevBlockHash := res [len (res )- 1 ]
278- block = append (block , toBytes (prevBlockHash )... )
279- res = append (res , BlockHash (xxhash .Sum64 (block )))
276+ h := xxhash .New ()
277+ _ , _ = h .Write ([]byte (request .TargetModel ))
278+ prevBlockHash := BlockHash (h .Sum64 ())
279+ for i := 0 ; i + cacheBlockSize <= len (userInput ); i += cacheBlockSize {
280+ h .Reset ()
281+ _ , _ = h .Write (userInput [i : i + cacheBlockSize ])
282+ _ , _ = h .Write (toBytes (prevBlockHash ))
283+ res = append (res , BlockHash (h .Sum64 ()))
284+
285+ prevBlockHash = res [len (res )- 1 ]
280286 }
281287 return res
282288}
@@ -286,3 +292,12 @@ func toBytes(i BlockHash) []byte {
286292 binary .LittleEndian .PutUint64 (bytes , uint64 (i ))
287293 return bytes
288294}
295+
296+ func getUserInputBytes (request * types.LLMRequest ) ([]byte , error ) {
297+ if request .Data .Completions != nil { // assumed to be valid if not nil
298+ return []byte (request .Data .Completions .Prompt ), nil
299+ }
300+
301+ // must be chat-completions request at this point, return bytes of entire messages
302+ return json .Marshal (request .Data .ChatCompletions .Messages )
303+ }
0 commit comments