9 changes: 9 additions & 0 deletions pkg/common/utils.go
@@ -21,6 +21,7 @@ import (
 "regexp"
 "sync"
 
+"github.com/go-logr/logr"
 "github.com/google/uuid"
 )
 
@@ -127,3 +128,11 @@ func init() {
 func Tokenize(text string) []string {
 return re.FindAllString(text, -1)
 }
+
+func WriteToChannel[T any](channel chan T, object T, logger logr.Logger, channelName string) {
+select {
+case channel <- object:
+default:
+logger.V(1).Info("failed to write to", "channel", channelName)
+}
+}
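
For context, the helper above performs a non-blocking send: if the channel has free buffer capacity the value is delivered, otherwise it is dropped and the drop is logged at verbosity 1. Below is a minimal, self-contained sketch of that behavior; the funcr logger, the channel, and the values are illustrative and not part of this PR.

package main

import (
	"fmt"

	"github.com/go-logr/logr/funcr"

	"github.com/llm-d/llm-d-inference-sim/pkg/common"
)

func main() {
	// funcr provides a simple logr.Logger backed by a print function;
	// verbosity 1 is enabled so the "failed to write to" message is visible.
	logger := funcr.New(func(prefix, args string) {
		fmt.Println(prefix, args)
	}, funcr.Options{Verbosity: 1})

	usage := make(chan float64, 1) // buffered channel with capacity 1

	common.WriteToChannel(usage, 0.25, logger, "usageChan") // buffer has room: delivered
	common.WriteToChannel(usage, 0.50, logger, "usageChan") // buffer full: dropped and logged at V(1)

	fmt.Println(<-usage) // prints 0.25; the second value was never enqueued
}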
20 changes: 13 additions & 7 deletions pkg/kv-cache/block_cache.go
@@ -70,10 +70,10 @@ func newBlockCache(config *common.Configuration, logger logr.Logger, usageChan c
 }, nil
 }
 
-func (b *blockCache) start(ctx context.Context) {
-err := b.eventSender.Run(ctx)
+func (bc *blockCache) start(ctx context.Context) {
+err := bc.eventSender.Run(ctx)
 if err != nil {
-b.logger.Info("sender stopped with error", "error", err)
+bc.logger.Info("sender stopped with error", "error", err)
 }
 }
 
@@ -139,20 +139,25 @@ func (bc *blockCache) startRequest(requestID string, blocks []uint64) (int, erro
 }
 
 delete(bc.unusedBlocks, oldestUnusedHash)
-bc.eventChan <- EventData{action: eventActionRemove, hashValues: []uint64{oldestUnusedHash}}
+common.WriteToChannel(bc.eventChan,
+EventData{action: eventActionRemove, hashValues: []uint64{oldestUnusedHash}},
+bc.logger, "block cache eventChan")
 }
 
 // Add the new block
 bc.usedBlocks[block] = 1
-bc.eventChan <- EventData{action: eventActionStore, hashValues: []uint64{block}}
+common.WriteToChannel(bc.eventChan,
+EventData{action: eventActionStore, hashValues: []uint64{block}},
+bc.logger, "block cache eventChan")
 }
 
 // store the request mapping
 bc.requestToBlocks[requestID] = make([]uint64, len(blocks))
 copy(bc.requestToBlocks[requestID], blocks)
 
 if bc.usageChan != nil {
-bc.usageChan <- float64(len(bc.usedBlocks)) / float64(bc.maxBlocks)
+common.WriteToChannel(bc.usageChan, float64(len(bc.usedBlocks))/float64(bc.maxBlocks),
+bc.logger, "block cache usageChan")
 }
 return len(blockAreadyInUse) + len(blockToMoveToUsed), nil
 }
@@ -188,7 +193,8 @@ func (bc *blockCache) finishRequest(requestID string) error {
 }
 
 if bc.usageChan != nil {
-bc.usageChan <- float64(len(bc.usedBlocks)) / float64(bc.maxBlocks)
+common.WriteToChannel(bc.usageChan, float64(len(bc.usedBlocks))/float64(bc.maxBlocks),
+bc.logger, "block cache usageChan")
 }
 
 // Remove the request mapping
3 changes: 2 additions & 1 deletion pkg/llm-d-inference-sim/lora.go
@@ -20,6 +20,7 @@ package llmdinferencesim
 import (
 "encoding/json"
 
+"github.com/llm-d/llm-d-inference-sim/pkg/common"
 "github.com/valyala/fasthttp"
 )
 
@@ -139,6 +140,6 @@ func (s *VllmSimulator) decrementLora(model string) {
 s.loras.loadedLoras[model]--
 if s.loras.loadedLoras[model] <= 0 {
 // last usage of this LoRA
-s.loras.loraRemovable <- 1
+common.WriteToChannel(s.loras.loraRemovable, 1, s.logger, "loraRemovable")
 }
 }
23 changes: 13 additions & 10 deletions pkg/llm-d-inference-sim/simulator.go
@@ -420,7 +420,7 @@ func (s *VllmSimulator) processing(ctx context.Context) {
 
 s.logger.V(4).Info("Sending the request to the processing channel", "model", model,
 "req id", reqCtx.CompletionReq.GetRequestID(), "worker", worker.id)
-worker.reqChan <- reqCtx
+common.WriteToChannel(worker.reqChan, reqCtx, s.logger, "worker's reqChan")
 }
 }
 }
@@ -431,9 +431,9 @@ func (s *VllmSimulator) findRequestAndSendToProcess(worker *worker) bool {
 // send this request for processing in this worker
 s.logger.V(4).Info("Sending request to processing", "model", nextReq.CompletionReq.GetModel(),
 "req", nextReq.CompletionReq.GetRequestID(), "worker", worker.id)
-worker.reqChan <- nextReq
+common.WriteToChannel(worker.reqChan, nextReq, s.logger, "worker's reqChan")
 // decrement waiting requests metric
-s.metrics.waitingReqChan <- -1
+common.WriteToChannel(s.metrics.waitingReqChan, -1, s.logger, "metrics.waitingReqChan")
 return true
 }
 
@@ -450,9 +450,11 @@ func (s *VllmSimulator) addRequestToQueue(reqCtx *openaiserverapi.CompletionReqC
 return
 }
 // increment the waiting requests metric
-s.metrics.waitingReqChan <- 1
+common.WriteToChannel(s.metrics.waitingReqChan, 1, s.logger, "metrics.waitingReqChan")
 // update loraInfo metrics with the new waiting request
-s.metrics.lorasChan <- loraUsage{reqCtx.CompletionReq.GetModel(), waitingUsageState}
+common.WriteToChannel(s.metrics.lorasChan, loraUsage{reqCtx.CompletionReq.GetModel(), waitingUsageState},
+s.logger, "metrics.lorasChan")
+
 }
 
 // handleCompletions general completion requests handler, support both text and chat completion APIs
@@ -487,18 +489,19 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple
 IsChatCompletion: isChatCompletion,
 Wg: &wg,
 }
-s.newRequests <- reqCtx
+common.WriteToChannel(s.newRequests, reqCtx, s.logger, "newRequests")
 wg.Wait()
 }
 
 // request processing finished
 func (s *VllmSimulator) responseSentCallback(model string, isChatCompletion bool, requestID string) {
 // decrement running requests count
-s.metrics.runReqChan <- -1
+common.WriteToChannel(s.metrics.runReqChan, -1, s.logger, "metrics.runReqChan")
 
 if s.isLora(model) {
 // update loraInfo metrics to reflect that the request processing has been finished
-s.metrics.lorasChan <- loraUsage{model, doneUsageState}
+common.WriteToChannel(s.metrics.lorasChan, loraUsage{model, doneUsageState},
+s.logger, "metrics.lorasChan")
 }
 
 if s.config.EnableKVCache && !isChatCompletion {
@@ -580,14 +583,14 @@ func (s *VllmSimulator) sendResponse(reqCtx *openaiserverapi.CompletionReqCtx, r
 time.Sleep(time.Duration(ttft) * time.Millisecond)
 
 // report ttft in seconds
-s.metrics.ttftChan <- (float64(ttft) / 1000)
+common.WriteToChannel(s.metrics.ttftChan, (float64(ttft) / 1000), s.logger, "metrics.ttftChan")
 
 for range usageData.CompletionTokens - 1 {
 perTokenLatency := s.getInterTokenLatency()
 time.Sleep(time.Duration(perTokenLatency) * time.Millisecond)
 
 // report tpot in seconds
-s.metrics.tpotChan <- float64(perTokenLatency) / 1000
+common.WriteToChannel(s.metrics.tpotChan, (float64(perTokenLatency) / 1000), s.logger, "metrics.tpotChan")
 }
 s.sendCompletionResponse(reqCtx.HTTPReqCtx, resp)
 
4 changes: 2 additions & 2 deletions pkg/llm-d-inference-sim/streaming.go
@@ -104,14 +104,14 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ
 ttft := s.getWaitTimeToFirstToken(context.nPromptTokens, context.nCachedPromptTokens, context.doRemotePrefill)
 time.Sleep(time.Duration(ttft) * time.Millisecond)
 // report ttft in seconds
-s.metrics.ttftChan <- (float64(ttft) / 1000)
+common.WriteToChannel(s.metrics.ttftChan, (float64(ttft) / 1000), s.logger, "metrics.ttftChan")
 
 for i, token := range genTokens {
 if i != 0 {
 interTokenLat := s.getInterTokenLatency()
 time.Sleep(time.Duration(interTokenLat) * time.Millisecond)
 // report tpot in seconds
-s.metrics.tpotChan <- float64(interTokenLat) / 1000
+common.WriteToChannel(s.metrics.tpotChan, (float64(interTokenLat) / 1000), s.logger, "metrics.tpotChan")
 }
 
 var toolChunkInsert *openaiserverapi.ToolCall
23 changes: 11 additions & 12 deletions pkg/llm-d-inference-sim/worker.go
@@ -21,6 +21,7 @@ import (
 "context"
 
 "github.com/go-logr/logr"
+"github.com/llm-d/llm-d-inference-sim/pkg/common"
 "github.com/llm-d/llm-d-inference-sim/pkg/dataset"
 openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
 "github.com/valyala/fasthttp"
@@ -63,12 +64,13 @@ func (s *VllmSimulator) processRequest(reqCtx *openaiserverapi.CompletionReqCtx)
 displayModel := s.getDisplayedModelName(model)
 
 // increment running requests count
-s.metrics.runReqChan <- 1
+common.WriteToChannel(s.metrics.runReqChan, 1, s.logger, "metrics.runReqChan")
 
 if s.isLora(model) {
 // update loraInfo metric to reflect that
 // the request has changed its status from waiting to running
-s.metrics.lorasChan <- loraUsage{model, runningUsageState}
+common.WriteToChannel(s.metrics.lorasChan, loraUsage{model, runningUsageState}, s.logger,
+"metrics.lorasChan")
 }
 
 if s.config.EnableKVCache && !reqCtx.IsChatCompletion {
@@ -137,16 +139,13 @@ func (s *VllmSimulator) processRequest(reqCtx *openaiserverapi.CompletionReqCtx)
 s.sendResponse(reqCtx, responseTokens, toolCalls, displayModel, finishReason, &usageData)
 }
 
-select {
-case s.metrics.requestSuccessChan <- requestSuccessEvent{
-promptTokens: usageData.PromptTokens,
-generationTokens: usageData.CompletionTokens,
-maxTokens: reqCtx.CompletionReq.GetMaxCompletionTokens(),
-finishReason: finishReason,
-}:
-default:
-s.logger.V(1).Info("requestSuccessChan full, dropping success event")
-}
+common.WriteToChannel(s.metrics.requestSuccessChan,
+requestSuccessEvent{
+promptTokens: usageData.PromptTokens,
+generationTokens: usageData.CompletionTokens,
+maxTokens: reqCtx.CompletionReq.GetMaxCompletionTokens(),
+finishReason: finishReason},
+s.logger, "metrics.requestSuccessChan")
 }
 
 s.logger.V(4).Info("Finished processing request", "id", req.GetRequestID())