Skip to content

Commit 639b40e

Browse files
authored
Take cached prompt tokens into account in prefill time calculation (#184)
* Take cached prompt tokens into account in prefill time calculation

  Signed-off-by: Ira <[email protected]>

* Review comments

  Signed-off-by: Ira <[email protected]>

---------

Signed-off-by: Ira <[email protected]>
1 parent 5821371 commit 639b40e

File tree

8 files changed

+127
-67
lines changed

8 files changed

+127
-67
lines changed

pkg/common/config.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,9 @@ func (c *Configuration) validate() error {
340340
if c.KVCacheTransferTimeStdDev < 0 {
341341
return errors.New("kv-cache tranfer time standard deviation cannot be negative")
342342
}
343+
if float32(c.KVCacheTransferTimeStdDev) > 0.3*float32(c.KVCacheTransferTimePerToken) {
344+
return errors.New("kv-cache tranfer time standard deviation cannot be more than 30% of kv-cache tranfer time")
345+
}
343346

344347
if c.KVCacheTransferLatency < 0 {
345348
return errors.New("kv-cache tranfer time cannot be negative")

pkg/kv-cache/block_cache.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,14 @@ func (b *blockCache) start(ctx context.Context) {
7676
}
7777

7878
// startRequest adds a request with its associated block hashes to the cache
79-
func (bc *blockCache) startRequest(requestID string, blocks []uint64) error {
79+
// and returns the number of blocks that were already in the cache
80+
func (bc *blockCache) startRequest(requestID string, blocks []uint64) (int, error) {
8081
bc.mu.Lock()
8182
defer bc.mu.Unlock()
8283

8384
if _, exists := bc.requestToBlocks[requestID]; exists {
8485
// request with the same id already exists
85-
return fmt.Errorf("request already exists for id %s", requestID)
86+
return 0, fmt.Errorf("request already exists for id %s", requestID)
8687
}
8788

8889
// divide list of blocks to three lists:
@@ -107,7 +108,7 @@ func (bc *blockCache) startRequest(requestID string, blocks []uint64) error {
107108
}
108109

109110
if len(bc.usedBlocks)+len(blocksToAdd)+len(blockToMoveToUsed) > bc.maxBlocks {
110-
return errors.New(capacityError)
111+
return 0, errors.New(capacityError)
111112
}
112113

113114
// for blocks that are already in use - update the reference
@@ -148,7 +149,7 @@ func (bc *blockCache) startRequest(requestID string, blocks []uint64) error {
148149
bc.requestToBlocks[requestID] = make([]uint64, len(blocks))
149150
copy(bc.requestToBlocks[requestID], blocks)
150151

151-
return nil
152+
return len(blockAreadyInUse) + len(blockToMoveToUsed), nil
152153
}
153154

154155
// finishRequest processes the completion of a request, decreasing reference counts
@@ -159,7 +160,7 @@ func (bc *blockCache) finishRequest(requestID string) error {
159160
// Get blocks associated with this request
160161
blockHashes, exists := bc.requestToBlocks[requestID]
161162
if !exists {
162-
return errors.New("request not found")
163+
return nil
163164
}
164165

165166
now := time.Now()

pkg/kv-cache/kv_cache.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ type KVCacheHelper struct {
3232
tokensProcessor kvblock.TokenProcessor // turns tokens to kv block keys
3333
logger logr.Logger
3434
blockCache *blockCache
35+
blockSize int
3536
}
3637

3738
func NewKVCacheHelper(config *common.Configuration, logger logr.Logger) (*KVCacheHelper, error) {
@@ -59,6 +60,7 @@ func NewKVCacheHelper(config *common.Configuration, logger logr.Logger) (*KVCach
5960
tokensProcessor: tokensProcessor,
6061
blockCache: blockCache,
6162
logger: logger,
63+
blockSize: config.TokenBlockSize,
6264
}, nil
6365
}
6466

@@ -78,7 +80,7 @@ func (h *KVCacheHelper) OnRequestStart(vllmReq openaiserverapi.CompletionRequest
7880
tokens, _, err := h.tokenizer.Encode(prompt, modelName)
7981
if err != nil {
8082
h.logger.Info("Prompt tokenization failed", "error", err.Error())
81-
return h.blockCache.startRequest(requestID, make([]uint64, 0))
83+
return err
8284
}
8385

8486
// get block keys
@@ -90,7 +92,9 @@ func (h *KVCacheHelper) OnRequestStart(vllmReq openaiserverapi.CompletionRequest
9092
blockHashes[i] = key.ChunkHash
9193
}
9294

93-
return h.blockCache.startRequest(requestID, blockHashes)
95+
nExistingBlocks, err := h.blockCache.startRequest(requestID, blockHashes)
96+
vllmReq.SetNumberOfCachedPromptTokens(nExistingBlocks * h.blockSize)
97+
return err
9498
}
9599

96100
func (h *KVCacheHelper) OnRequestEnd(vllmReq openaiserverapi.CompletionRequest) error {

pkg/kv-cache/kv_cache_test.go

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ var _ = Describe("KV cache", Ordered, func() {
237237
var err error
238238
switch action.action {
239239
case actionStartRequest:
240-
err = blockCache.startRequest(action.request.id, action.request.blocks)
240+
_, err = blockCache.startRequest(action.request.id, action.request.blocks)
241241
case actionFinishRequest:
242242
err = blockCache.finishRequest(action.request.id)
243243
}
@@ -344,17 +344,21 @@ var _ = Describe("KV cache", Ordered, func() {
344344
req4 := testRequest{"req4", []uint64{5, 6}}
345345

346346
// blocks 1 and 2 stored
347-
err = blockCache.startRequest(req1.id, req1.blocks)
347+
alreadyInCache, err := blockCache.startRequest(req1.id, req1.blocks)
348348
Expect(err).NotTo(HaveOccurred())
349+
Expect(alreadyInCache).To(Equal(0))
349350
// blocks 3 and 4 stored
350-
err = blockCache.startRequest(req2.id, req2.blocks)
351+
alreadyInCache, err = blockCache.startRequest(req2.id, req2.blocks)
351352
Expect(err).NotTo(HaveOccurred())
353+
Expect(alreadyInCache).To(Equal(0))
352354
// no new blocks stored, reuse of 1 and 3
353-
err = blockCache.startRequest(req3.id, req3.blocks)
355+
alreadyInCache, err = blockCache.startRequest(req3.id, req3.blocks)
354356
Expect(err).NotTo(HaveOccurred())
357+
Expect(alreadyInCache).To(Equal(2))
355358
// no space left - should fail
356-
err = blockCache.startRequest(req4.id, req4.blocks)
359+
alreadyInCache, err = blockCache.startRequest(req4.id, req4.blocks)
357360
Expect(err).To(HaveOccurred())
361+
Expect(alreadyInCache).To(Equal(0))
358362

359363
err = blockCache.finishRequest(req1.id)
360364
Expect(err).NotTo(HaveOccurred())
@@ -363,8 +367,9 @@ var _ = Describe("KV cache", Ordered, func() {
363367
// now 2 and 4 are not in use
364368

365369
// blocks 2 and 4 should be removed, and 5 and 6 stored
366-
err = blockCache.startRequest(req4.id, req4.blocks)
370+
alreadyInCache, err = blockCache.startRequest(req4.id, req4.blocks)
367371
Expect(err).NotTo(HaveOccurred())
372+
Expect(alreadyInCache).To(Equal(0))
368373
}()
369374

370375
removedBlocks := make([]uint64, 0)
@@ -431,7 +436,7 @@ var _ = Describe("KV cache", Ordered, func() {
431436
reqID := fmt.Sprintf("req_%d_%d", id, j)
432437
blocks := createRandomArray(testCase.minBlockLen, testCase.maxBlockLen, testCase.maxHashValue)
433438

434-
err := blockCache.startRequest(reqID, blocks)
439+
_, err := blockCache.startRequest(reqID, blocks)
435440
if err != nil {
436441
// some operations may fail due to cache being full, which is expected
437442
Expect(err.Error()).To(Equal(capacityError))

pkg/llm-d-inference-sim/simulator.go

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,6 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple
382382
if s.config.EnableKVCache && !isChatCompletion {
383383
err := s.kvcacheHelper.OnRequestEnd(vllmReq)
384384
if err != nil {
385-
// TODO should it be an error with http response error or just a warning?
386385
s.logger.Error(err, "kv cache failed to process request end")
387386
}
388387
}
@@ -391,8 +390,7 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple
391390
// kv cache is currently supported for /completion API only
392391
err = s.kvcacheHelper.OnRequestStart(vllmReq)
393392
if err != nil {
394-
// TODO should it be an error with http response error or just a warning?
395-
s.logger.Error(err, "kv cache failed to process request start")
393+
s.sendCompletionError(ctx, openaiserverapi.NewCompletionError(err.Error(), fasthttp.StatusInternalServerError, nil), false)
396394
}
397395
}
398396

@@ -490,28 +488,22 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
490488
}
491489
s.sendStreamingResponse(
492490
&streamingContext{
493-
ctx: reqCtx.HTTPReqCtx,
494-
isChatCompletion: reqCtx.IsChatCompletion,
495-
model: displayModel,
496-
doRemotePrefill: req.IsDoRemotePrefill(),
491+
ctx: reqCtx.HTTPReqCtx,
492+
isChatCompletion: reqCtx.IsChatCompletion,
493+
model: displayModel,
494+
doRemotePrefill: req.IsDoRemotePrefill(),
495+
nPromptTokens: usageData.PromptTokens,
496+
nCachedPromptTokens: reqCtx.CompletionReq.GetNumberOfCachedPromptTokens(),
497497
},
498-
usageData.PromptTokens, responseTokens, toolCalls, finishReason, usageDataToSend,
498+
responseTokens, toolCalls, finishReason, usageDataToSend,
499499
)
500500
} else {
501501
if req.IsDoRemoteDecode() {
502502
// in case this is prefill pod processing, return special finish reason
503503
finishReason = common.RemoteDecodeFinishReason
504504
}
505505

506-
s.sendResponse(reqCtx.IsChatCompletion,
507-
reqCtx.HTTPReqCtx,
508-
responseTokens,
509-
toolCalls,
510-
displayModel,
511-
finishReason,
512-
&usageData,
513-
req.IsDoRemoteDecode(),
514-
req.IsDoRemotePrefill())
506+
s.sendResponse(reqCtx, responseTokens, toolCalls, displayModel, finishReason, &usageData)
515507
}
516508
}
517509
reqCtx.Wg.Done()
@@ -628,17 +620,19 @@ func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respToke
628620
}
629621

630622
// sendResponse sends response for completion API, supports both completions (text and chat)
631-
// according the value of isChatCompletion
623+
// according the value of isChatCompletion in reqCtx
632624
// respTokens - tokenized content to be sent in the response
633625
// toolCalls - tool calls to be sent in the response
634626
// modelName - display name returned to the client and used in metrics. It is either the first alias
635627
// from --served-model-name (for a base-model request) or the LoRA adapter name (for a LoRA request).
636628
// finishReason - a pointer to string that represents finish reason, can be nil, stop, length, or tools
637629
// usageData - usage (tokens statistics) for this response
638-
func (s *VllmSimulator) sendResponse(isChatCompletion bool, ctx *fasthttp.RequestCtx, respTokens []string, toolCalls []openaiserverapi.ToolCall,
639-
modelName string, finishReason string, usageData *openaiserverapi.Usage, doRemoteDecode bool, doRemotePrefill bool) {
640-
resp := s.createCompletionResponse(isChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName, doRemoteDecode)
630+
func (s *VllmSimulator) sendResponse(reqCtx *openaiserverapi.CompletionReqCtx, respTokens []string, toolCalls []openaiserverapi.ToolCall,
631+
modelName string, finishReason string, usageData *openaiserverapi.Usage) {
632+
resp := s.createCompletionResponse(reqCtx.IsChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName,
633+
reqCtx.CompletionReq.IsDoRemoteDecode())
641634

635+
ctx := reqCtx.HTTPReqCtx
642636
data, err := json.Marshal(resp)
643637
if err != nil {
644638
ctx.Error("Response body creation failed, "+err.Error(), fasthttp.StatusInternalServerError)
@@ -647,8 +641,10 @@ func (s *VllmSimulator) sendResponse(isChatCompletion bool, ctx *fasthttp.Reques
647641

648642
// calculate how long to wait before returning the response, time is based on number of tokens
649643
nPromptTokens := usageData.PromptTokens
644+
nCachedPromptTokens := reqCtx.CompletionReq.GetNumberOfCachedPromptTokens()
650645
nGenTokens := usageData.CompletionTokens
651-
totalMillisToWait := s.getTimeToFirstToken(nPromptTokens, doRemotePrefill) + s.getTotalInterTokenLatency(nGenTokens)
646+
ttft := s.getTimeToFirstToken(nPromptTokens, nCachedPromptTokens, reqCtx.CompletionReq.IsDoRemotePrefill())
647+
totalMillisToWait := ttft + s.getTotalInterTokenLatency(nGenTokens)
652648
time.Sleep(time.Duration(totalMillisToWait) * time.Millisecond)
653649

654650
ctx.Response.Header.SetContentType("application/json")
@@ -666,7 +662,7 @@ func (s *VllmSimulator) sendResponse(isChatCompletion bool, ctx *fasthttp.Reques
666662
}
667663

668664
// returns time to first token based on the current request's doRemotePrefill
669-
func (s *VllmSimulator) getTimeToFirstToken(nPromptTokens int, doRemotePrefill bool) int {
665+
func (s *VllmSimulator) getTimeToFirstToken(nPromptTokens int, nCachedPromptTokens int, doRemotePrefill bool) int {
670666
if doRemotePrefill {
671667
if s.config.KVCacheTransferLatency == 0 && s.config.KVCacheTransferLatencyStdDev == 0 {
672668
// is disaggregated PD and ttft is calculated using number of prompt tokens
@@ -677,8 +673,8 @@ func (s *VllmSimulator) getTimeToFirstToken(nPromptTokens int, doRemotePrefill b
677673
return int(common.RandomNorm(float64(s.config.KVCacheTransferLatency), float64(s.config.KVCacheTransferLatencyStdDev)))
678674
}
679675
if s.config.TimeToFirstToken == 0 && s.config.TimeToFirstTokenStdDev == 0 {
680-
// is aggregated PD and ttft is calculated using number of prompt tokens
681-
prefillTime := s.config.PrefillOverhead + nPromptTokens*s.config.PrefillTimePerToken
676+
// is aggregated PD and ttft is calculated using number of prompt tokens that are not in kv cache
677+
prefillTime := s.config.PrefillOverhead + (nPromptTokens-nCachedPromptTokens)*s.config.PrefillTimePerToken
682678
return int(common.RandomNorm(float64(prefillTime), float64(s.config.PrefillTimeStdDev)))
683679
}
684680
// is aggregated PD and *not* using number of prompt tokens

pkg/llm-d-inference-sim/simulator_test.go

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -807,7 +807,7 @@ var _ = Describe("Simulator", func() {
807807
simulator.config.TimeToFirstTokenStdDev = timeToFirstTokenStdDev
808808
simulator.config.KVCacheTransferLatency = kvCacheLatency
809809
simulator.config.KVCacheTransferLatencyStdDev = kvCacheLatencyStdDev
810-
timeToFirst := simulator.getTimeToFirstToken(1, doREmotePrefill)
810+
timeToFirst := simulator.getTimeToFirstToken(1, 0, doREmotePrefill)
811811
if doREmotePrefill {
812812
Expect(timeToFirst).To(BeNumerically(">=", int(float32(kvCacheLatency)*0.3)))
813813
Expect(timeToFirst).To(BeNumerically("<=", int(float32(kvCacheLatency)*1.7)))
@@ -838,7 +838,7 @@ var _ = Describe("Simulator", func() {
838838
simulator.config.PrefillTimePerToken = 200
839839
simulator.config.PrefillTimeStdDev = 80
840840

841-
ttft := simulator.getTimeToFirstToken(128, false)
841+
ttft := simulator.getTimeToFirstToken(128, 0, false)
842842

843843
Expect(ttft).To(BeNumerically("==", timeToFirstToken))
844844
})
@@ -851,33 +851,60 @@ var _ = Describe("Simulator", func() {
851851
simulator.config.PrefillTimePerToken = 200
852852
simulator.config.PrefillTimeStdDev = 80
853853

854-
ttft := simulator.getTimeToFirstToken(128, false)
854+
ttft := simulator.getTimeToFirstToken(128, 0, false)
855855
Expect(ttft).NotTo(BeNumerically("==", 0))
856856
})
857857

858-
DescribeTable("time to first token is against number of prompt tokens",
859-
func(prefillOverhead int, prefillTimePerToken int, stdDev int, nTokens int) {
858+
DescribeTable("time to first token is against number of prompt tokens with std",
859+
func(prefillOverhead int, prefillTimePerToken int, stdDev int, nTokens int, nCachedTokens int) {
860860
simulator.config.TimeToFirstToken = 0
861861
simulator.config.PrefillOverhead = prefillOverhead
862862
simulator.config.PrefillTimePerToken = prefillTimePerToken
863863
simulator.config.PrefillTimeStdDev = stdDev
864864

865-
ttft := simulator.getTimeToFirstToken(nTokens, false)
865+
ttft := simulator.getTimeToFirstToken(nTokens, nCachedTokens, false)
866866

867-
expectedTTFT := prefillOverhead + prefillTimePerToken*nTokens
867+
expectedTTFT := prefillOverhead + prefillTimePerToken*(nTokens-nCachedTokens)
868868
Expect(ttft).To(BeNumerically(">=", int(float64(expectedTTFT)*0.3)))
869869
Expect(ttft).To(BeNumerically("<=", int(float64(expectedTTFT)*1.7)))
870+
},
871+
func(prefillOverhead int, prefillTimePerToken, stdDev int, nTokens int, nCachedTokens int) string {
872+
return fmt.Sprintf("prefillOverhead: %d, prefillTimePerToken: %d, stdDev: %d, nTokens: %d nCachedTokens: %d",
873+
prefillOverhead, prefillTimePerToken, stdDev, nTokens, nCachedTokens)
874+
},
875+
Entry("single token", 100, 50, 10, 1, 0),
876+
Entry("single token big std", 100, 50, 70, 1, 0),
877+
Entry("stddev is 0", 100, 50, 0, 1, 0),
878+
Entry("medium overhead, 512 tokens", 200, 1000, 150, 512, 0),
879+
Entry("large overhead, 1024 tokens", 2000, 3000, 800, 1024, 0),
880+
Entry("very long prompt", 150, 200, 70, 20000, 0),
881+
Entry("medium overhead, 512 tokens, 256 cached", 200, 1000, 150, 512, 256),
882+
Entry("large overhead, 1024 tokens, 1008 cached", 2000, 3000, 800, 1024, 1008),
883+
Entry("very long prompt, 1024 cached", 150, 200, 70, 20000, 1024),
884+
)
885+
886+
DescribeTable("time to first token is against number of prompt tokens",
887+
func(prefillOverhead int, prefillTimePerToken int, nTokens int, nCachedTokens int) {
888+
simulator.config.TimeToFirstToken = 0
889+
simulator.config.PrefillOverhead = prefillOverhead
890+
simulator.config.PrefillTimePerToken = prefillTimePerToken
891+
simulator.config.PrefillTimeStdDev = 0
870892

893+
ttft := simulator.getTimeToFirstToken(nTokens, nCachedTokens, false)
894+
expectedTTFT := prefillOverhead + prefillTimePerToken*(nTokens-nCachedTokens)
895+
Expect(ttft).To(Equal(expectedTTFT))
871896
},
872-
func(prefillOverhead int, prefillTimePerToken, stdDev int, nTokens int) string {
873-
return fmt.Sprintf("prefillOverhead: %d, prefillTimePerToken: %d, stdDev: %d, nTokens: %d",
874-
prefillOverhead, prefillTimePerToken, stdDev, nTokens)
897+
func(prefillOverhead int, prefillTimePerToken, nTokens int, nCachedTokens int) string {
898+
return fmt.Sprintf("prefillOverhead: %d, prefillTimePerToken: %d, nTokens: %d nCachedTokens: %d",
899+
prefillOverhead, prefillTimePerToken, nTokens, nCachedTokens)
875900
},
876-
Entry("single token", 100, 50, 70, 1),
877-
Entry("stddev is 0", 100, 50, 0, 1),
878-
Entry("medium overhead, 512 tokens", 200, 1000, 150, 512),
879-
Entry("large overhead, 1024 tokens", 2000, 3000, 1800, 1024),
880-
Entry("very long prompt", 150, 200, 100, 20000),
901+
Entry("single token", 100, 50, 1, 0),
902+
Entry("medium overhead, 512 tokens", 200, 1000, 512, 0),
903+
Entry("large overhead, 1024 tokens", 2000, 3000, 1024, 0),
904+
Entry("very long prompt", 150, 200, 20000, 0),
905+
Entry("medium overhead, 512 tokens, 256 cached", 200, 1000, 512, 256),
906+
Entry("large overhead, 1024 tokens, 128 cached", 2000, 3000, 1024, 128),
907+
Entry("very long prompt, 1024 cached", 150, 200, 20000, 1024),
881908
)
882909

883910
It("when <kv-cache-transfer-latency> not 0, ignore <kv-cache-transfer-overhead>", func() {
@@ -887,7 +914,7 @@ var _ = Describe("Simulator", func() {
887914
simulator.config.KVCacheTransferTimePerToken = 100
888915
simulator.config.KVCacheTransferTimeStdDev = 0
889916

890-
ttft := simulator.getTimeToFirstToken(128, true)
917+
ttft := simulator.getTimeToFirstToken(128, 0, true)
891918
Expect(ttft).To(BeNumerically("==", 200))
892919
})
893920

@@ -898,7 +925,7 @@ var _ = Describe("Simulator", func() {
898925
simulator.config.KVCacheTransferTimePerToken = 100
899926
simulator.config.KVCacheTransferTimeStdDev = 0
900927

901-
ttft := simulator.getTimeToFirstToken(128, true)
928+
ttft := simulator.getTimeToFirstToken(128, 0, true)
902929
Expect(ttft).To(BeNumerically("==", 12800))
903930
})
904931

@@ -909,7 +936,7 @@ var _ = Describe("Simulator", func() {
909936
simulator.config.KVCacheTransferTimePerToken = kvCacheTransTPT
910937
simulator.config.KVCacheTransferTimeStdDev = stddev
911938

912-
ttft := simulator.getTimeToFirstToken(nTokens, true)
939+
ttft := simulator.getTimeToFirstToken(nTokens, 0, true)
913940

914941
expectedTTFT := kvCacheTransTPT * nTokens
915942
Expect(ttft).To(BeNumerically(">=", int(float64(expectedTTFT)*0.3)))

0 commit comments

Comments
 (0)