diff --git a/pkg/common/config.go b/pkg/common/config.go
index 3afa45dc..5256c655 100644
--- a/pkg/common/config.go
+++ b/pkg/common/config.go
@@ -340,6 +340,9 @@ func (c *Configuration) validate() error {
 	if c.KVCacheTransferTimeStdDev < 0 {
 		return errors.New("kv-cache tranfer time standard deviation cannot be negative")
 	}
+	if float32(c.KVCacheTransferTimeStdDev) > 0.3*float32(c.KVCacheTransferTimePerToken) {
+		return errors.New("kv-cache transfer time standard deviation cannot be more than 30% of kv-cache transfer time")
+	}
 
 	if c.KVCacheTransferLatency < 0 {
 		return errors.New("kv-cache tranfer time cannot be negative")
diff --git a/pkg/kv-cache/block_cache.go b/pkg/kv-cache/block_cache.go
index 9b358c4c..9045dd1b 100644
--- a/pkg/kv-cache/block_cache.go
+++ b/pkg/kv-cache/block_cache.go
@@ -76,13 +76,14 @@ func (b *blockCache) start(ctx context.Context) {
 }
 
 // startRequest adds a request with its associated block hashes to the cache
-func (bc *blockCache) startRequest(requestID string, blocks []uint64) error {
+// and returns the number of blocks that were already in the cache
+func (bc *blockCache) startRequest(requestID string, blocks []uint64) (int, error) {
 	bc.mu.Lock()
 	defer bc.mu.Unlock()
 
 	if _, exists := bc.requestToBlocks[requestID]; exists {
 		// request with the same id already exists
-		return fmt.Errorf("request already exists for id %s", requestID)
+		return 0, fmt.Errorf("request already exists for id %s", requestID)
 	}
 
 	// divide list of blocks to three lists:
@@ -107,7 +108,7 @@ func (bc *blockCache) startRequest(requestID string, blocks []uint64) error {
 	}
 
 	if len(bc.usedBlocks)+len(blocksToAdd)+len(blockToMoveToUsed) > bc.maxBlocks {
-		return errors.New(capacityError)
+		return 0, errors.New(capacityError)
 	}
 
 	// for blocks that are already in use - update the reference
@@ -148,7 +149,7 @@ func (bc *blockCache) startRequest(requestID string, blocks []uint64) error {
 	bc.requestToBlocks[requestID] = make([]uint64, len(blocks))
 	copy(bc.requestToBlocks[requestID], blocks)
 
-	return nil
+	return len(blockAreadyInUse) + len(blockToMoveToUsed), nil
 }
 
 // finishRequest processes the completion of a request, decreasing reference counts
@@ -159,7 +160,7 @@ func (bc *blockCache) finishRequest(requestID string) error {
 	// Get blocks associated with this request
 	blockHashes, exists := bc.requestToBlocks[requestID]
 	if !exists {
-		return errors.New("request not found")
+		return nil
 	}
 
 	now := time.Now()
diff --git a/pkg/kv-cache/kv_cache.go b/pkg/kv-cache/kv_cache.go
index dbbd7645..5c5819db 100644
--- a/pkg/kv-cache/kv_cache.go
+++ b/pkg/kv-cache/kv_cache.go
@@ -32,6 +32,7 @@ type KVCacheHelper struct {
 	tokensProcessor kvblock.TokenProcessor // turns tokens to kv block keys
 	logger          logr.Logger
 	blockCache      *blockCache
+	blockSize       int
 }
 
 func NewKVCacheHelper(config *common.Configuration, logger logr.Logger) (*KVCacheHelper, error) {
@@ -59,6 +60,7 @@ func NewKVCacheHelper(config *common.Configuration, logger logr.Logger) (*KVCach
 		tokensProcessor: tokensProcessor,
 		blockCache:      blockCache,
 		logger:          logger,
+		blockSize:       config.TokenBlockSize,
 	}, nil
 }
 
@@ -78,7 +80,7 @@ func (h *KVCacheHelper) OnRequestStart(vllmReq openaiserverapi.CompletionRequest
 	tokens, _, err := h.tokenizer.Encode(prompt, modelName)
 	if err != nil {
 		h.logger.Info("Prompt tokenization failed", "error", err.Error())
-		return h.blockCache.startRequest(requestID, make([]uint64, 0))
+		return err
 	}
 
 	// get block keys
@@ -90,7 +92,9 @@ func (h *KVCacheHelper) OnRequestStart(vllmReq openaiserverapi.CompletionRequest
 		blockHashes[i] = key.ChunkHash
 	}
 
-	return h.blockCache.startRequest(requestID, blockHashes)
+	nExistingBlocks, err := h.blockCache.startRequest(requestID, blockHashes)
+	vllmReq.SetNumberOfCachedPromptTokens(nExistingBlocks * h.blockSize)
+	return err
 }
 
 func (h *KVCacheHelper) OnRequestEnd(vllmReq openaiserverapi.CompletionRequest) error {
diff --git a/pkg/kv-cache/kv_cache_test.go b/pkg/kv-cache/kv_cache_test.go
index fce1b44e..172c9ced 100644
--- a/pkg/kv-cache/kv_cache_test.go
+++ b/pkg/kv-cache/kv_cache_test.go
@@ -237,7 +237,7 @@ var _ = Describe("KV cache", Ordered, func() {
 			var err error
 			switch action.action {
 			case actionStartRequest:
-				err = blockCache.startRequest(action.request.id, action.request.blocks)
+				_, err = blockCache.startRequest(action.request.id, action.request.blocks)
 			case actionFinishRequest:
 				err = blockCache.finishRequest(action.request.id)
 			}
@@ -344,17 +344,21 @@ var _ = Describe("KV cache", Ordered, func() {
 			req4 := testRequest{"req4", []uint64{5, 6}}
 
 			// blocks 1 and 2 stored
-			err = blockCache.startRequest(req1.id, req1.blocks)
+			alreadyInCache, err := blockCache.startRequest(req1.id, req1.blocks)
 			Expect(err).NotTo(HaveOccurred())
+			Expect(alreadyInCache).To(Equal(0))
 			// blocks 3 and 4 stored
-			err = blockCache.startRequest(req2.id, req2.blocks)
+			alreadyInCache, err = blockCache.startRequest(req2.id, req2.blocks)
 			Expect(err).NotTo(HaveOccurred())
+			Expect(alreadyInCache).To(Equal(0))
 			// no new blocks stored, reuse of 1 and 3
-			err = blockCache.startRequest(req3.id, req3.blocks)
+			alreadyInCache, err = blockCache.startRequest(req3.id, req3.blocks)
 			Expect(err).NotTo(HaveOccurred())
+			Expect(alreadyInCache).To(Equal(2))
 			// no space left - should fail
-			err = blockCache.startRequest(req4.id, req4.blocks)
+			alreadyInCache, err = blockCache.startRequest(req4.id, req4.blocks)
 			Expect(err).To(HaveOccurred())
+			Expect(alreadyInCache).To(Equal(0))
 
 			err = blockCache.finishRequest(req1.id)
 			Expect(err).NotTo(HaveOccurred())
@@ -363,8 +367,9 @@ var _ = Describe("KV cache", Ordered, func() {
 
 			// now 2 and 4 are not in use
 			// blocks 2 and 4 should be removed, and 5 and 6 stored
-			err = blockCache.startRequest(req4.id, req4.blocks)
+			alreadyInCache, err = blockCache.startRequest(req4.id, req4.blocks)
 			Expect(err).NotTo(HaveOccurred())
+			Expect(alreadyInCache).To(Equal(0))
 		}()
 
 		removedBlocks := make([]uint64, 0)
@@ -431,7 +436,7 @@ var _ = Describe("KV cache", Ordered, func() {
 					reqID := fmt.Sprintf("req_%d_%d", id, j)
 					blocks := createRandomArray(testCase.minBlockLen, testCase.maxBlockLen, testCase.maxHashValue)
 
-					err := blockCache.startRequest(reqID, blocks)
+					_, err := blockCache.startRequest(reqID, blocks)
 					if err != nil {
 						// some operations may fail due to cache being full, which is expected
 						Expect(err.Error()).To(Equal(capacityError))
diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go
index cc2a06e0..026a55c4 100644
--- a/pkg/llm-d-inference-sim/simulator.go
+++ b/pkg/llm-d-inference-sim/simulator.go
@@ -382,7 +382,6 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple
 	if s.config.EnableKVCache && !isChatCompletion {
 		err := s.kvcacheHelper.OnRequestEnd(vllmReq)
 		if err != nil {
-			// TODO should it be an error with http response error or just a warning?
 			s.logger.Error(err, "kv cache failed to process request end")
 		}
 	}
@@ -391,8 +390,7 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple
 		// kv cache is currently supported for /completion API only
 		err = s.kvcacheHelper.OnRequestStart(vllmReq)
 		if err != nil {
-			// TODO should it be an error with http response error or just a warning?
-			s.logger.Error(err, "kv cache failed to process request start")
+			s.sendCompletionError(ctx, openaiserverapi.NewCompletionError(err.Error(), fasthttp.StatusInternalServerError, nil), false)
 		}
 	}
 
@@ -490,12 +488,14 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
 			}
 			s.sendStreamingResponse(
 				&streamingContext{
-					ctx:              reqCtx.HTTPReqCtx,
-					isChatCompletion: reqCtx.IsChatCompletion,
-					model:            displayModel,
-					doRemotePrefill:  req.IsDoRemotePrefill(),
+					ctx:                 reqCtx.HTTPReqCtx,
+					isChatCompletion:    reqCtx.IsChatCompletion,
+					model:               displayModel,
+					doRemotePrefill:     req.IsDoRemotePrefill(),
+					nPromptTokens:       usageData.PromptTokens,
+					nCachedPromptTokens: reqCtx.CompletionReq.GetNumberOfCachedPromptTokens(),
 				},
-				usageData.PromptTokens, responseTokens, toolCalls, finishReason, usageDataToSend,
+				responseTokens, toolCalls, finishReason, usageDataToSend,
 			)
 		} else {
 			if req.IsDoRemoteDecode() {
@@ -503,15 +503,7 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
 				finishReason = common.RemoteDecodeFinishReason
 			}
 
-			s.sendResponse(reqCtx.IsChatCompletion,
-				reqCtx.HTTPReqCtx,
-				responseTokens,
-				toolCalls,
-				displayModel,
-				finishReason,
-				&usageData,
-				req.IsDoRemoteDecode(),
-				req.IsDoRemotePrefill())
+			s.sendResponse(reqCtx, responseTokens, toolCalls, displayModel, finishReason, &usageData)
 		}
 	}
 	reqCtx.Wg.Done()
@@ -628,17 +620,19 @@ func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respToke
 }
 
 // sendResponse sends response for completion API, supports both completions (text and chat)
-// according the value of isChatCompletion
+// according to the value of isChatCompletion in reqCtx
 // respTokens - tokenized content to be sent in the response
 // toolCalls - tool calls to be sent in the response
 // modelName - display name returned to the client and used in metrics. It is either the first alias
 // from --served-model-name (for a base-model request) or the LoRA adapter name (for a LoRA request).
 // finishReason - a pointer to string that represents finish reason, can be nil, stop, length, or tools
 // usageData - usage (tokens statistics) for this response
-func (s *VllmSimulator) sendResponse(isChatCompletion bool, ctx *fasthttp.RequestCtx, respTokens []string, toolCalls []openaiserverapi.ToolCall,
-	modelName string, finishReason string, usageData *openaiserverapi.Usage, doRemoteDecode bool, doRemotePrefill bool) {
-	resp := s.createCompletionResponse(isChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName, doRemoteDecode)
+func (s *VllmSimulator) sendResponse(reqCtx *openaiserverapi.CompletionReqCtx, respTokens []string, toolCalls []openaiserverapi.ToolCall,
+	modelName string, finishReason string, usageData *openaiserverapi.Usage) {
+	resp := s.createCompletionResponse(reqCtx.IsChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName,
+		reqCtx.CompletionReq.IsDoRemoteDecode())
+	ctx := reqCtx.HTTPReqCtx
 
 	data, err := json.Marshal(resp)
 	if err != nil {
 		ctx.Error("Response body creation failed, "+err.Error(), fasthttp.StatusInternalServerError)
@@ -647,8 +641,10 @@ func (s *VllmSimulator) sendResponse(isChatCompletion bool, ctx *fasthttp.Reques
 
 	// calculate how long to wait before returning the response, time is based on number of tokens
 	nPromptTokens := usageData.PromptTokens
+	nCachedPromptTokens := reqCtx.CompletionReq.GetNumberOfCachedPromptTokens()
 	nGenTokens := usageData.CompletionTokens
-	totalMillisToWait := s.getTimeToFirstToken(nPromptTokens, doRemotePrefill) + s.getTotalInterTokenLatency(nGenTokens)
+	ttft := s.getTimeToFirstToken(nPromptTokens, nCachedPromptTokens, reqCtx.CompletionReq.IsDoRemotePrefill())
+	totalMillisToWait := ttft + s.getTotalInterTokenLatency(nGenTokens)
 	time.Sleep(time.Duration(totalMillisToWait) * time.Millisecond)
 
 	ctx.Response.Header.SetContentType("application/json")
@@ -666,7 +662,7 @@ func (s *VllmSimulator) sendResponse(isChatCompletion bool, ctx *fasthttp.Reques
 }
 
 // returns time to first token based on the current request's doRemotePrefill
-func (s *VllmSimulator) getTimeToFirstToken(nPromptTokens int, doRemotePrefill bool) int {
+func (s *VllmSimulator) getTimeToFirstToken(nPromptTokens int, nCachedPromptTokens int, doRemotePrefill bool) int {
 	if doRemotePrefill {
 		if s.config.KVCacheTransferLatency == 0 && s.config.KVCacheTransferLatencyStdDev == 0 {
 			// is disaggregated PD and ttft is calculated using number of prompt tokens
@@ -677,8 +673,8 @@ func (s *VllmSimulator) getTimeToFirstToken(nPromptTokens int, doRemotePrefill b
 		return int(common.RandomNorm(float64(s.config.KVCacheTransferLatency), float64(s.config.KVCacheTransferLatencyStdDev)))
 	}
 	if s.config.TimeToFirstToken == 0 && s.config.TimeToFirstTokenStdDev == 0 {
-		// is aggregated PD and ttft is calculated using number of prompt tokens
-		prefillTime := s.config.PrefillOverhead + nPromptTokens*s.config.PrefillTimePerToken
+		// is aggregated PD and ttft is calculated using number of prompt tokens that are not in kv cache
+		prefillTime := s.config.PrefillOverhead + (nPromptTokens-nCachedPromptTokens)*s.config.PrefillTimePerToken
 		return int(common.RandomNorm(float64(prefillTime), float64(s.config.PrefillTimeStdDev)))
 	}
 	// is aggregated PD and *not* using number of prompt tokens
diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go
index 73699fbb..1c9c8805 100644
--- a/pkg/llm-d-inference-sim/simulator_test.go
+++ b/pkg/llm-d-inference-sim/simulator_test.go
@@ -807,7 +807,7 @@ var _ = Describe("Simulator", func() {
Describe("Simulator", func() { simulator.config.TimeToFirstTokenStdDev = timeToFirstTokenStdDev simulator.config.KVCacheTransferLatency = kvCacheLatency simulator.config.KVCacheTransferLatencyStdDev = kvCacheLatencyStdDev - timeToFirst := simulator.getTimeToFirstToken(1, doREmotePrefill) + timeToFirst := simulator.getTimeToFirstToken(1, 0, doREmotePrefill) if doREmotePrefill { Expect(timeToFirst).To(BeNumerically(">=", int(float32(kvCacheLatency)*0.3))) Expect(timeToFirst).To(BeNumerically("<=", int(float32(kvCacheLatency)*1.7))) @@ -838,7 +838,7 @@ var _ = Describe("Simulator", func() { simulator.config.PrefillTimePerToken = 200 simulator.config.PrefillTimeStdDev = 80 - ttft := simulator.getTimeToFirstToken(128, false) + ttft := simulator.getTimeToFirstToken(128, 0, false) Expect(ttft).To(BeNumerically("==", timeToFirstToken)) }) @@ -851,33 +851,60 @@ var _ = Describe("Simulator", func() { simulator.config.PrefillTimePerToken = 200 simulator.config.PrefillTimeStdDev = 80 - ttft := simulator.getTimeToFirstToken(128, false) + ttft := simulator.getTimeToFirstToken(128, 0, false) Expect(ttft).NotTo(BeNumerically("==", 0)) }) - DescribeTable("time to first token is against number of prompt tokens", - func(prefillOverhead int, prefillTimePerToken int, stdDev int, nTokens int) { + DescribeTable("time to first token is against number of prompt tokens with std", + func(prefillOverhead int, prefillTimePerToken int, stdDev int, nTokens int, nCachedTokens int) { simulator.config.TimeToFirstToken = 0 simulator.config.PrefillOverhead = prefillOverhead simulator.config.PrefillTimePerToken = prefillTimePerToken simulator.config.PrefillTimeStdDev = stdDev - ttft := simulator.getTimeToFirstToken(nTokens, false) + ttft := simulator.getTimeToFirstToken(nTokens, nCachedTokens, false) - expectedTTFT := prefillOverhead + prefillTimePerToken*nTokens + expectedTTFT := prefillOverhead + prefillTimePerToken*(nTokens-nCachedTokens) Expect(ttft).To(BeNumerically(">=", int(float64(expectedTTFT)*0.3))) Expect(ttft).To(BeNumerically("<=", int(float64(expectedTTFT)*1.7))) + }, + func(prefillOverhead int, prefillTimePerToken, stdDev int, nTokens int, nCachedTokens int) string { + return fmt.Sprintf("prefillOverhead: %d, prefillTimePerToken: %d, stdDev: %d, nTokens: %d nCachedTokens: %d", + prefillOverhead, prefillTimePerToken, stdDev, nTokens, nCachedTokens) + }, + Entry("single token", 100, 50, 10, 1, 0), + Entry("single token big std", 100, 50, 70, 1, 0), + Entry("stddev is 0", 100, 50, 0, 1, 0), + Entry("medium overhead, 512 tokens", 200, 1000, 150, 512, 0), + Entry("large overhead, 1024 tokens", 2000, 3000, 800, 1024, 0), + Entry("very long prompt", 150, 200, 70, 20000, 0), + Entry("medium overhead, 512 tokens, 256 cached", 200, 1000, 150, 512, 256), + Entry("large overhead, 1024 tokens, 1008 cached", 2000, 3000, 800, 1024, 1008), + Entry("very long prompt, 1024 cached", 150, 200, 70, 20000, 1024), + ) + + DescribeTable("time to first token is against number of prompt tokens", + func(prefillOverhead int, prefillTimePerToken int, nTokens int, nCachedTokens int) { + simulator.config.TimeToFirstToken = 0 + simulator.config.PrefillOverhead = prefillOverhead + simulator.config.PrefillTimePerToken = prefillTimePerToken + simulator.config.PrefillTimeStdDev = 0 + ttft := simulator.getTimeToFirstToken(nTokens, nCachedTokens, false) + expectedTTFT := prefillOverhead + prefillTimePerToken*(nTokens-nCachedTokens) + Expect(ttft).To(Equal(expectedTTFT)) }, - func(prefillOverhead int, prefillTimePerToken, stdDev int, nTokens int) 
-				return fmt.Sprintf("prefillOverhead: %d, prefillTimePerToken: %d, stdDev: %d, nTokens: %d",
-					prefillOverhead, prefillTimePerToken, stdDev, nTokens)
+			func(prefillOverhead int, prefillTimePerToken, nTokens int, nCachedTokens int) string {
+				return fmt.Sprintf("prefillOverhead: %d, prefillTimePerToken: %d, nTokens: %d nCachedTokens: %d",
+					prefillOverhead, prefillTimePerToken, nTokens, nCachedTokens)
 			},
-			Entry("single token", 100, 50, 70, 1),
-			Entry("stddev is 0", 100, 50, 0, 1),
-			Entry("medium overhead, 512 tokens", 200, 1000, 150, 512),
-			Entry("large overhead, 1024 tokens", 2000, 3000, 1800, 1024),
-			Entry("very long prompt", 150, 200, 100, 20000),
+			Entry("single token", 100, 50, 1, 0),
+			Entry("medium overhead, 512 tokens", 200, 1000, 512, 0),
+			Entry("large overhead, 1024 tokens", 2000, 3000, 1024, 0),
+			Entry("very long prompt", 150, 200, 20000, 0),
+			Entry("medium overhead, 512 tokens, 256 cached", 200, 1000, 512, 256),
+			Entry("large overhead, 1024 tokens, 128 cached", 2000, 3000, 1024, 128),
+			Entry("very long prompt, 1024 cached", 150, 200, 20000, 1024),
 		)
 
 		It("when not 0, ignore ", func() {
@@ -887,7 +914,7 @@ var _ = Describe("Simulator", func() {
 			simulator.config.KVCacheTransferLatency = 0
 			simulator.config.KVCacheTransferTimePerToken = 100
 			simulator.config.KVCacheTransferTimeStdDev = 0
 
-			ttft := simulator.getTimeToFirstToken(128, true)
+			ttft := simulator.getTimeToFirstToken(128, 0, true)
 			Expect(ttft).To(BeNumerically("==", 200))
 		})
 
@@ -898,7 +925,7 @@ var _ = Describe("Simulator", func() {
 			simulator.config.KVCacheTransferLatency = 0
 			simulator.config.KVCacheTransferTimePerToken = 100
 			simulator.config.KVCacheTransferTimeStdDev = 0
 
-			ttft := simulator.getTimeToFirstToken(128, true)
+			ttft := simulator.getTimeToFirstToken(128, 0, true)
 			Expect(ttft).To(BeNumerically("==", 12800))
 		})
 
@@ -909,7 +936,7 @@ var _ = Describe("Simulator", func() {
 			simulator.config.KVCacheTransferTimePerToken = kvCacheTransTPT
 			simulator.config.KVCacheTransferTimeStdDev = stddev
 
-			ttft := simulator.getTimeToFirstToken(nTokens, true)
+			ttft := simulator.getTimeToFirstToken(nTokens, 0, true)
 			expectedTTFT := kvCacheTransTPT * nTokens
 
 			Expect(ttft).To(BeNumerically(">=", int(float64(expectedTTFT)*0.3)))
diff --git a/pkg/llm-d-inference-sim/streaming.go b/pkg/llm-d-inference-sim/streaming.go
index d234114a..5ff1e240 100644
--- a/pkg/llm-d-inference-sim/streaming.go
+++ b/pkg/llm-d-inference-sim/streaming.go
@@ -28,18 +28,20 @@ import (
 )
 
 type streamingContext struct {
-	ctx              *fasthttp.RequestCtx
-	isChatCompletion bool
-	model            string
-	creationTime     int64
-	doRemotePrefill  bool
+	ctx                 *fasthttp.RequestCtx
+	isChatCompletion    bool
+	model               string
+	creationTime        int64
+	doRemotePrefill     bool
+	nPromptTokens       int
+	nCachedPromptTokens int
 }
 
 // sendStreamingResponse creates and sends a streaming response for completion requests of both types (text and chat)
 // as defined by isChatCompletion
 // response content is wrapped according SSE format
 // First token is send after timeToFirstToken milliseconds, every other token is sent after interTokenLatency milliseconds
-func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, nPromptTokens int, responseTokens []string, toolCalls []openaiserverapi.ToolCall,
+func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, responseTokens []string, toolCalls []openaiserverapi.ToolCall,
 	finishReason string, usageData *openaiserverapi.Usage) {
 	context.ctx.SetContentType("text/event-stream")
 	context.ctx.SetStatusCode(fasthttp.StatusOK)
@@ -67,11 +69,11 @@ func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, nPrompt
 	if len(toolCalls) > 0 {
 		s.logger.Info("Going to send tools calls")
 		for _, tc := range toolCalls {
-			s.sendTokenChunks(context, w, nPromptTokens, tc.Function.TokenizedArguments, &tc, finishReason)
+			s.sendTokenChunks(context, w, tc.Function.TokenizedArguments, &tc, finishReason)
 		}
 	} else {
 		s.logger.Info("Going to send text", "number of tokens", len(responseTokens))
-		s.sendTokenChunks(context, w, nPromptTokens, responseTokens, nil, finishReason)
+		s.sendTokenChunks(context, w, responseTokens, nil, finishReason)
 	}
 }
 
@@ -94,9 +96,11 @@ func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, nPrompt
 }
 
 // sendTokenChunks creates and sends response chunks
-func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writer, nPromptTokens int, genTokens []string, tc *openaiserverapi.ToolCall, finishReason string) {
+func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writer, genTokens []string,
+	tc *openaiserverapi.ToolCall, finishReason string) {
 	// time to first token delay
-	time.Sleep(time.Duration(s.getTimeToFirstToken(nPromptTokens, context.doRemotePrefill)) * time.Millisecond)
+	ttft := s.getTimeToFirstToken(context.nPromptTokens, context.nCachedPromptTokens, context.doRemotePrefill)
+	time.Sleep(time.Duration(ttft) * time.Millisecond)
 
 	for i, token := range genTokens {
 		if i != 0 {
diff --git a/pkg/openai-server-api/request.go b/pkg/openai-server-api/request.go
index b23104f8..675db162 100644
--- a/pkg/openai-server-api/request.go
+++ b/pkg/openai-server-api/request.go
@@ -45,6 +45,12 @@ type CompletionRequest interface {
 	IncludeUsage() bool
 	// GetNumberOfPromptTokens returns the number of tokens in the prompt
 	GetNumberOfPromptTokens() int
+	// GetNumberOfCachedPromptTokens returns the number of tokens in the prompt that are
+	// in the local KV Cache
+	GetNumberOfCachedPromptTokens() int
+	// SetNumberOfCachedPromptTokens sets the number of tokens in the prompt that are
+	// in the local KV Cache
+	SetNumberOfCachedPromptTokens(cachedPromptTokens int)
 	// GetPrompt returns the prompt
 	GetPrompt() string
 	// GetTools() returns tools to use (in chat completion)
@@ -85,6 +91,8 @@ type baseCompletionRequest struct {
 	RemoteHost string `json:"remote_host"`
 	// RemotePort is a port of the remote server handling prefill
 	RemotePort int `json:"remote_port"`
+	// The number of tokens in the prompt that are in the local KV Cache
+	cachedPromptTokens int
 }
 
 // StreamOptions defines streaming options for streaming requests
@@ -117,6 +125,18 @@ func (b *baseCompletionRequest) IsDoRemotePrefill() bool {
 	return b.DoRemotePrefill
 }
 
+// GetNumberOfCachedPromptTokens returns the number of tokens in the prompt that are
+// in the local KV Cache
+func (b *baseCompletionRequest) GetNumberOfCachedPromptTokens() int {
+	return b.cachedPromptTokens
+}
+
+// SetNumberOfCachedPromptTokens sets the number of tokens in the prompt that are
+// in the local KV Cache
+func (b *baseCompletionRequest) SetNumberOfCachedPromptTokens(cachedPromptTokens int) {
+	b.cachedPromptTokens = cachedPromptTokens
+}
+
 // CompletionReqCtx is a context passed in the simulator's flow, it contains the request data needed
 // to generate the simulator's response
 type CompletionReqCtx struct {