Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pkg/common/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,9 @@ func (c *Configuration) validate() error {
if c.KVCacheTransferTimeStdDev < 0 {
return errors.New("kv-cache tranfer time standard deviation cannot be negative")
}
if float32(c.KVCacheTransferTimeStdDev) > 0.3*float32(c.KVCacheTransferTimePerToken) {
return errors.New("kv-cache tranfer time standard deviation cannot be more than 30% of kv-cache tranfer time")
}

if c.KVCacheTransferLatency < 0 {
return errors.New("kv-cache tranfer time cannot be negative")
Expand Down
11 changes: 6 additions & 5 deletions pkg/kv-cache/block_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,14 @@ func (b *blockCache) start(ctx context.Context) {
}

// startRequest adds a request with its associated block hashes to the cache
func (bc *blockCache) startRequest(requestID string, blocks []uint64) error {
// and returns the number of blocks that were already in the cache
func (bc *blockCache) startRequest(requestID string, blocks []uint64) (int, error) {
bc.mu.Lock()
defer bc.mu.Unlock()

if _, exists := bc.requestToBlocks[requestID]; exists {
// request with the same id already exists
return fmt.Errorf("request already exists for id %s", requestID)
return 0, fmt.Errorf("request already exists for id %s", requestID)
}

// divide the list of blocks into three lists:
Expand All @@ -107,7 +108,7 @@ func (bc *blockCache) startRequest(requestID string, blocks []uint64) error {
}

if len(bc.usedBlocks)+len(blocksToAdd)+len(blockToMoveToUsed) > bc.maxBlocks {
return errors.New(capacityError)
return 0, errors.New(capacityError)
}

// for blocks that are already in use - update the reference
Expand Down Expand Up @@ -148,7 +149,7 @@ func (bc *blockCache) startRequest(requestID string, blocks []uint64) error {
bc.requestToBlocks[requestID] = make([]uint64, len(blocks))
copy(bc.requestToBlocks[requestID], blocks)

return nil
return len(blockAreadyInUse) + len(blockToMoveToUsed), nil
}

// finishRequest processes the completion of a request, decreasing reference counts
Expand All @@ -159,7 +160,7 @@ func (bc *blockCache) finishRequest(requestID string) error {
// Get blocks associated with this request
blockHashes, exists := bc.requestToBlocks[requestID]
if !exists {
return errors.New("request not found")
return nil
}

now := time.Now()
Expand Down
8 changes: 6 additions & 2 deletions pkg/kv-cache/kv_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type KVCacheHelper struct {
tokensProcessor kvblock.TokenProcessor // turns tokens to kv block keys
logger logr.Logger
blockCache *blockCache
blockSize int
}

func NewKVCacheHelper(config *common.Configuration, logger logr.Logger) (*KVCacheHelper, error) {
Expand Down Expand Up @@ -59,6 +60,7 @@ func NewKVCacheHelper(config *common.Configuration, logger logr.Logger) (*KVCach
tokensProcessor: tokensProcessor,
blockCache: blockCache,
logger: logger,
blockSize: config.TokenBlockSize,
}, nil
}

Expand All @@ -78,7 +80,7 @@ func (h *KVCacheHelper) OnRequestStart(vllmReq openaiserverapi.CompletionRequest
tokens, _, err := h.tokenizer.Encode(prompt, modelName)
if err != nil {
h.logger.Info("Prompt tokenization failed", "error", err.Error())
return h.blockCache.startRequest(requestID, make([]uint64, 0))
return err
}

// get block keys
Expand All @@ -90,7 +92,9 @@ func (h *KVCacheHelper) OnRequestStart(vllmReq openaiserverapi.CompletionRequest
blockHashes[i] = key.ChunkHash
}

return h.blockCache.startRequest(requestID, blockHashes)
nExistingBlocks, err := h.blockCache.startRequest(requestID, blockHashes)
vllmReq.SetNumberOfCachedPromptTokens(nExistingBlocks * h.blockSize)
return err
}

func (h *KVCacheHelper) OnRequestEnd(vllmReq openaiserverapi.CompletionRequest) error {
Expand Down
19 changes: 12 additions & 7 deletions pkg/kv-cache/kv_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ var _ = Describe("KV cache", Ordered, func() {
var err error
switch action.action {
case actionStartRequest:
err = blockCache.startRequest(action.request.id, action.request.blocks)
_, err = blockCache.startRequest(action.request.id, action.request.blocks)
case actionFinishRequest:
err = blockCache.finishRequest(action.request.id)
}
Expand Down Expand Up @@ -344,17 +344,21 @@ var _ = Describe("KV cache", Ordered, func() {
req4 := testRequest{"req4", []uint64{5, 6}}

// blocks 1 and 2 stored
err = blockCache.startRequest(req1.id, req1.blocks)
alreadyInCache, err := blockCache.startRequest(req1.id, req1.blocks)
Expect(err).NotTo(HaveOccurred())
Expect(alreadyInCache).To(Equal(0))
// blocks 3 and 4 stored
err = blockCache.startRequest(req2.id, req2.blocks)
alreadyInCache, err = blockCache.startRequest(req2.id, req2.blocks)
Expect(err).NotTo(HaveOccurred())
Expect(alreadyInCache).To(Equal(0))
// no new blocks stored, reuse of 1 and 3
err = blockCache.startRequest(req3.id, req3.blocks)
alreadyInCache, err = blockCache.startRequest(req3.id, req3.blocks)
Expect(err).NotTo(HaveOccurred())
Expect(alreadyInCache).To(Equal(2))
// no space left - should fail
err = blockCache.startRequest(req4.id, req4.blocks)
alreadyInCache, err = blockCache.startRequest(req4.id, req4.blocks)
Expect(err).To(HaveOccurred())
Expect(alreadyInCache).To(Equal(0))

err = blockCache.finishRequest(req1.id)
Expect(err).NotTo(HaveOccurred())
Expand All @@ -363,8 +367,9 @@ var _ = Describe("KV cache", Ordered, func() {
// now 2 and 4 are not in use

// blocks 2 and 4 should be removed, and 5 and 6 stored
err = blockCache.startRequest(req4.id, req4.blocks)
alreadyInCache, err = blockCache.startRequest(req4.id, req4.blocks)
Expect(err).NotTo(HaveOccurred())
Expect(alreadyInCache).To(Equal(0))
}()

removedBlocks := make([]uint64, 0)
Expand Down Expand Up @@ -431,7 +436,7 @@ var _ = Describe("KV cache", Ordered, func() {
reqID := fmt.Sprintf("req_%d_%d", id, j)
blocks := createRandomArray(testCase.minBlockLen, testCase.maxBlockLen, testCase.maxHashValue)

err := blockCache.startRequest(reqID, blocks)
_, err := blockCache.startRequest(reqID, blocks)
if err != nil {
// some operations may fail due to cache being full, which is expected
Expect(err.Error()).To(Equal(capacityError))
Expand Down
51 changes: 25 additions & 26 deletions pkg/llm-d-inference-sim/simulator.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,6 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple
if s.config.EnableKVCache && !isChatCompletion {
err := s.kvcacheHelper.OnRequestEnd(vllmReq)
if err != nil {
// TODO should it be an error with http response error or just a warning?
s.logger.Error(err, "kv cache failed to process request end")
}
}
Expand All @@ -391,8 +390,7 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple
// kv cache is currently supported for /completion API only
err = s.kvcacheHelper.OnRequestStart(vllmReq)
if err != nil {
// TODO should it be an error with http response error or just a warning?
s.logger.Error(err, "kv cache failed to process request start")
s.sendCompletionError(ctx, openaiserverapi.NewCompletionError(err.Error(), fasthttp.StatusInternalServerError, nil), false)
}
}

Expand Down Expand Up @@ -490,28 +488,22 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
}
s.sendStreamingResponse(
&streamingContext{
ctx: reqCtx.HTTPReqCtx,
isChatCompletion: reqCtx.IsChatCompletion,
model: displayModel,
doRemotePrefill: req.IsDoRemotePrefill(),
ctx: reqCtx.HTTPReqCtx,
isChatCompletion: reqCtx.IsChatCompletion,
model: displayModel,
doRemotePrefill: req.IsDoRemotePrefill(),
nPromptTokens: usageData.PromptTokens,
nCachedPromptTokens: reqCtx.CompletionReq.GetNumberOfCachedPromptTokens(),
},
usageData.PromptTokens, responseTokens, toolCalls, finishReason, usageDataToSend,
responseTokens, toolCalls, finishReason, usageDataToSend,
)
} else {
if req.IsDoRemoteDecode() {
// in case this is prefill pod processing, return special finish reason
finishReason = common.RemoteDecodeFinishReason
}

s.sendResponse(reqCtx.IsChatCompletion,
reqCtx.HTTPReqCtx,
responseTokens,
toolCalls,
displayModel,
finishReason,
&usageData,
req.IsDoRemoteDecode(),
req.IsDoRemotePrefill())
s.sendResponse(reqCtx, responseTokens, toolCalls, displayModel, finishReason, &usageData)
}
}
reqCtx.Wg.Done()
Expand Down Expand Up @@ -628,17 +620,19 @@ func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respToke
}

// sendResponse sends response for completion API, supports both completions (text and chat)
// according to the value of isChatCompletion
// according to the value of isChatCompletion in reqCtx
// respTokens - tokenized content to be sent in the response
// toolCalls - tool calls to be sent in the response
// modelName - display name returned to the client and used in metrics. It is either the first alias
// from --served-model-name (for a base-model request) or the LoRA adapter name (for a LoRA request).
// finishReason - a pointer to string that represents finish reason, can be nil, stop, length, or tools
// usageData - usage (tokens statistics) for this response
func (s *VllmSimulator) sendResponse(isChatCompletion bool, ctx *fasthttp.RequestCtx, respTokens []string, toolCalls []openaiserverapi.ToolCall,
modelName string, finishReason string, usageData *openaiserverapi.Usage, doRemoteDecode bool, doRemotePrefill bool) {
resp := s.createCompletionResponse(isChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName, doRemoteDecode)
func (s *VllmSimulator) sendResponse(reqCtx *openaiserverapi.CompletionReqCtx, respTokens []string, toolCalls []openaiserverapi.ToolCall,
modelName string, finishReason string, usageData *openaiserverapi.Usage) {
resp := s.createCompletionResponse(reqCtx.IsChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName,
reqCtx.CompletionReq.IsDoRemoteDecode())

ctx := reqCtx.HTTPReqCtx
data, err := json.Marshal(resp)
if err != nil {
ctx.Error("Response body creation failed, "+err.Error(), fasthttp.StatusInternalServerError)
Expand All @@ -647,8 +641,10 @@ func (s *VllmSimulator) sendResponse(isChatCompletion bool, ctx *fasthttp.Reques

// calculate how long to wait before returning the response, time is based on number of tokens
nPromptTokens := usageData.PromptTokens
nCachedPromptTokens := reqCtx.CompletionReq.GetNumberOfCachedPromptTokens()
nGenTokens := usageData.CompletionTokens
totalMillisToWait := s.getTimeToFirstToken(nPromptTokens, doRemotePrefill) + s.getTotalInterTokenLatency(nGenTokens)
ttft := s.getTimeToFirstToken(nPromptTokens, nCachedPromptTokens, reqCtx.CompletionReq.IsDoRemotePrefill())
totalMillisToWait := ttft + s.getTotalInterTokenLatency(nGenTokens)
time.Sleep(time.Duration(totalMillisToWait) * time.Millisecond)

ctx.Response.Header.SetContentType("application/json")
Expand All @@ -666,7 +662,7 @@ func (s *VllmSimulator) sendResponse(isChatCompletion bool, ctx *fasthttp.Reques
}

// returns time to first token based on the current request's doRemotePrefill
func (s *VllmSimulator) getTimeToFirstToken(nPromptTokens int, doRemotePrefill bool) int {
func (s *VllmSimulator) getTimeToFirstToken(nPromptTokens int, nCachedPromptTokens int, doRemotePrefill bool) int {
if doRemotePrefill {
if s.config.KVCacheTransferLatency == 0 && s.config.KVCacheTransferLatencyStdDev == 0 {
// is disaggregated PD and ttft is calculated using number of prompt tokens
Expand All @@ -677,9 +673,12 @@ func (s *VllmSimulator) getTimeToFirstToken(nPromptTokens int, doRemotePrefill b
return int(common.RandomNorm(float64(s.config.KVCacheTransferLatency), float64(s.config.KVCacheTransferLatencyStdDev)))
}
if s.config.TimeToFirstToken == 0 && s.config.TimeToFirstTokenStdDev == 0 {
// is aggregated PD and ttft is calculated using number of prompt tokens
prefillTime := s.config.PrefillOverhead + nPromptTokens*s.config.PrefillTimePerToken
return int(common.RandomNorm(float64(prefillTime), float64(s.config.PrefillTimeStdDev)))
// is aggregated PD and ttft is calculated using number of prompt tokens that are not in kv cache
prefillTime := s.config.PrefillOverhead + (nPromptTokens-nCachedPromptTokens)*s.config.PrefillTimePerToken
fmt.Println("prefillTime ", prefillTime, " float ", float64(prefillTime))
res := int(common.RandomNorm(float64(prefillTime), float64(s.config.PrefillTimeStdDev)))
fmt.Println("res ", res)
return res
}
// is aggregated PD and *not* using number of prompt tokens
return int(common.RandomNorm(float64(s.config.TimeToFirstToken), float64(s.config.TimeToFirstTokenStdDev)))
Expand Down
63 changes: 45 additions & 18 deletions pkg/llm-d-inference-sim/simulator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,7 @@ var _ = Describe("Simulator", func() {
simulator.config.TimeToFirstTokenStdDev = timeToFirstTokenStdDev
simulator.config.KVCacheTransferLatency = kvCacheLatency
simulator.config.KVCacheTransferLatencyStdDev = kvCacheLatencyStdDev
timeToFirst := simulator.getTimeToFirstToken(1, doREmotePrefill)
timeToFirst := simulator.getTimeToFirstToken(1, 0, doREmotePrefill)
if doREmotePrefill {
Expect(timeToFirst).To(BeNumerically(">=", int(float32(kvCacheLatency)*0.3)))
Expect(timeToFirst).To(BeNumerically("<=", int(float32(kvCacheLatency)*1.7)))
Expand Down Expand Up @@ -838,7 +838,7 @@ var _ = Describe("Simulator", func() {
simulator.config.PrefillTimePerToken = 200
simulator.config.PrefillTimeStdDev = 80

ttft := simulator.getTimeToFirstToken(128, false)
ttft := simulator.getTimeToFirstToken(128, 0, false)

Expect(ttft).To(BeNumerically("==", timeToFirstToken))
})
Expand All @@ -851,33 +851,60 @@ var _ = Describe("Simulator", func() {
simulator.config.PrefillTimePerToken = 200
simulator.config.PrefillTimeStdDev = 80

ttft := simulator.getTimeToFirstToken(128, false)
ttft := simulator.getTimeToFirstToken(128, 0, false)
Expect(ttft).NotTo(BeNumerically("==", 0))
})

DescribeTable("time to first token is against number of prompt tokens",
func(prefillOverhead int, prefillTimePerToken int, stdDev int, nTokens int) {
DescribeTable("time to first token is against number of prompt tokens with std",
func(prefillOverhead int, prefillTimePerToken int, stdDev int, nTokens int, nCachedTokens int) {
simulator.config.TimeToFirstToken = 0
simulator.config.PrefillOverhead = prefillOverhead
simulator.config.PrefillTimePerToken = prefillTimePerToken
simulator.config.PrefillTimeStdDev = stdDev

ttft := simulator.getTimeToFirstToken(nTokens, false)
ttft := simulator.getTimeToFirstToken(nTokens, nCachedTokens, false)

expectedTTFT := prefillOverhead + prefillTimePerToken*nTokens
expectedTTFT := prefillOverhead + prefillTimePerToken*(nTokens-nCachedTokens)
Expect(ttft).To(BeNumerically(">=", int(float64(expectedTTFT)*0.3)))
Expect(ttft).To(BeNumerically("<=", int(float64(expectedTTFT)*1.7)))
},
func(prefillOverhead int, prefillTimePerToken, stdDev int, nTokens int, nCachedTokens int) string {
return fmt.Sprintf("prefillOverhead: %d, prefillTimePerToken: %d, stdDev: %d, nTokens: %d nCachedTokens: %d",
prefillOverhead, prefillTimePerToken, stdDev, nTokens, nCachedTokens)
},
Entry("single token", 100, 50, 10, 1, 0),
Entry("single token big std", 100, 50, 70, 1, 0),
Entry("stddev is 0", 100, 50, 0, 1, 0),
Entry("medium overhead, 512 tokens", 200, 1000, 150, 512, 0),
Entry("large overhead, 1024 tokens", 2000, 3000, 800, 1024, 0),
Entry("very long prompt", 150, 200, 70, 20000, 0),
Entry("medium overhead, 512 tokens, 256 cached", 200, 1000, 150, 512, 256),
Entry("large overhead, 1024 tokens, 128 cached", 2000, 3000, 800, 1024, 1008),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo in name of the test

Entry("very long prompt, 1024 cached", 150, 200, 70, 20000, 1024),
)

DescribeTable("time to first token is against number of prompt tokens",
func(prefillOverhead int, prefillTimePerToken int, nTokens int, nCachedTokens int) {
simulator.config.TimeToFirstToken = 0
simulator.config.PrefillOverhead = prefillOverhead
simulator.config.PrefillTimePerToken = prefillTimePerToken
simulator.config.PrefillTimeStdDev = 0

ttft := simulator.getTimeToFirstToken(nTokens, nCachedTokens, false)
expectedTTFT := prefillOverhead + prefillTimePerToken*(nTokens-nCachedTokens)
Expect(ttft).To(Equal(expectedTTFT))
},
func(prefillOverhead int, prefillTimePerToken, stdDev int, nTokens int) string {
return fmt.Sprintf("prefillOverhead: %d, prefillTimePerToken: %d, stdDev: %d, nTokens: %d",
prefillOverhead, prefillTimePerToken, stdDev, nTokens)
func(prefillOverhead int, prefillTimePerToken, nTokens int, nCachedTokens int) string {
return fmt.Sprintf("prefillOverhead: %d, prefillTimePerToken: %d, nTokens: %d nCachedTokens: %d",
prefillOverhead, prefillTimePerToken, nTokens, nCachedTokens)
},
Entry("single token", 100, 50, 70, 1),
Entry("stddev is 0", 100, 50, 0, 1),
Entry("medium overhead, 512 tokens", 200, 1000, 150, 512),
Entry("large overhead, 1024 tokens", 2000, 3000, 1800, 1024),
Entry("very long prompt", 150, 200, 100, 20000),
Entry("single token", 100, 50, 1, 0),
Entry("medium overhead, 512 tokens", 200, 1000, 512, 0),
Entry("large overhead, 1024 tokens", 2000, 3000, 1024, 0),
Entry("very long prompt", 150, 200, 20000, 0),
Entry("medium overhead, 512 tokens, 256 cached", 200, 1000, 512, 256),
Entry("large overhead, 1024 tokens, 128 cached", 2000, 3000, 1024, 128),
Entry("very long prompt, 1024 cached", 150, 200, 20000, 1024),
)

It("when <kv-cache-transfer-latency> not 0, ignore <kv-cache-transfer-overhead>", func() {
Expand All @@ -887,7 +914,7 @@ var _ = Describe("Simulator", func() {
simulator.config.KVCacheTransferTimePerToken = 100
simulator.config.KVCacheTransferTimeStdDev = 0

ttft := simulator.getTimeToFirstToken(128, true)
ttft := simulator.getTimeToFirstToken(128, 0, true)
Expect(ttft).To(BeNumerically("==", 200))
})

Expand All @@ -898,7 +925,7 @@ var _ = Describe("Simulator", func() {
simulator.config.KVCacheTransferTimePerToken = 100
simulator.config.KVCacheTransferTimeStdDev = 0

ttft := simulator.getTimeToFirstToken(128, true)
ttft := simulator.getTimeToFirstToken(128, 0, true)
Expect(ttft).To(BeNumerically("==", 12800))
})

Expand All @@ -909,7 +936,7 @@ var _ = Describe("Simulator", func() {
simulator.config.KVCacheTransferTimePerToken = kvCacheTransTPT
simulator.config.KVCacheTransferTimeStdDev = stddev

ttft := simulator.getTimeToFirstToken(nTokens, true)
ttft := simulator.getTimeToFirstToken(nTokens, 0, true)

expectedTTFT := kvCacheTransTPT * nTokens
Expect(ttft).To(BeNumerically(">=", int(float64(expectedTTFT)*0.3)))
Expand Down
Loading