
Commit 64d8d7f

Make writing to channels non-blocking (#225)

* Made writing to channels non-blocking
* Lint

Signed-off-by: irar2 <[email protected]>

1 parent 61c1c29 commit 64d8d7f

File tree: 6 files changed, +50 -32 lines changed

pkg/common/utils.go

Lines changed: 9 additions & 0 deletions
@@ -21,6 +21,7 @@ import (
 	"regexp"
 	"sync"
 
+	"github.com/go-logr/logr"
 	"github.com/google/uuid"
 )
 
@@ -127,3 +128,11 @@ func init() {
 func Tokenize(text string) []string {
 	return re.FindAllString(text, -1)
 }
+
+func WriteToChannel[T any](channel chan T, object T, logger logr.Logger, channelName string) {
+	select {
+	case channel <- object:
+	default:
+		logger.V(1).Info("failed to write to", "channel", channelName)
+	}
+}
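
The new helper makes every send best-effort: if the channel has no free buffer slot and no ready receiver, the value is dropped and a V(1) log line is emitted instead of blocking the caller. A minimal sketch of the semantics (the local writeToChannel copy, the discard logger, and the capacity-1 channel are illustration-only assumptions, not code from this repository):

package main

import (
	"fmt"

	"github.com/go-logr/logr"
)

// writeToChannel mirrors the helper added above: try the send,
// otherwise log at debug level and move on without blocking.
func writeToChannel[T any](channel chan T, object T, logger logr.Logger, channelName string) {
	select {
	case channel <- object:
	default:
		logger.V(1).Info("failed to write to", "channel", channelName)
	}
}

func main() {
	logger := logr.Discard() // stand-in for the simulator's logger
	ch := make(chan int, 1)  // buffered channel with room for one value

	writeToChannel(ch, 1, logger, "demo") // buffer has room: value is stored
	writeToChannel(ch, 2, logger, "demo") // buffer full, no receiver: value is dropped

	fmt.Println(<-ch) // prints 1; the second write never blocked the caller
}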

pkg/kv-cache/block_cache.go

Lines changed: 13 additions & 7 deletions
@@ -70,10 +70,10 @@ func newBlockCache(config *common.Configuration, logger logr.Logger, usageChan c
 	}, nil
 }
 
-func (b *blockCache) start(ctx context.Context) {
-	err := b.eventSender.Run(ctx)
+func (bc *blockCache) start(ctx context.Context) {
+	err := bc.eventSender.Run(ctx)
 	if err != nil {
-		b.logger.Info("sender stopped with error", "error", err)
+		bc.logger.Info("sender stopped with error", "error", err)
 	}
 }
 
@@ -139,20 +139,25 @@ func (bc *blockCache) startRequest(requestID string, blocks []uint64) (int, erro
 			}
 
 			delete(bc.unusedBlocks, oldestUnusedHash)
-			bc.eventChan <- EventData{action: eventActionRemove, hashValues: []uint64{oldestUnusedHash}}
+			common.WriteToChannel(bc.eventChan,
+				EventData{action: eventActionRemove, hashValues: []uint64{oldestUnusedHash}},
+				bc.logger, "block cache eventChan")
 		}
 
 		// Add the new block
 		bc.usedBlocks[block] = 1
-		bc.eventChan <- EventData{action: eventActionStore, hashValues: []uint64{block}}
+		common.WriteToChannel(bc.eventChan,
+			EventData{action: eventActionStore, hashValues: []uint64{block}},
+			bc.logger, "block cache eventChan")
 	}
 
 	// store the request mapping
 	bc.requestToBlocks[requestID] = make([]uint64, len(blocks))
 	copy(bc.requestToBlocks[requestID], blocks)
 
 	if bc.usageChan != nil {
-		bc.usageChan <- float64(len(bc.usedBlocks)) / float64(bc.maxBlocks)
+		common.WriteToChannel(bc.usageChan, float64(len(bc.usedBlocks))/float64(bc.maxBlocks),
+			bc.logger, "block cache usageChan")
 	}
 	return len(blockAreadyInUse) + len(blockToMoveToUsed), nil
 }

@@ -188,7 +193,8 @@ func (bc *blockCache) finishRequest(requestID string) error {
 	}
 
 	if bc.usageChan != nil {
-		bc.usageChan <- float64(len(bc.usedBlocks)) / float64(bc.maxBlocks)
+		common.WriteToChannel(bc.usageChan, float64(len(bc.usedBlocks))/float64(bc.maxBlocks),
+			bc.logger, "block cache usageChan")
 	}
 
 	// Remove the request mapping
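
Dropping a usage update here is harmless as long as a consumer keeps draining the channel, because the next successful write carries the up-to-date ratio. A sketch of that producer/consumer shape (the buffer size and the consumer loop are assumptions for illustration, not taken from the simulator):

package main

import (
	"fmt"
	"time"
)

func main() {
	// Stand-in for the block cache's usage channel; the real buffer size
	// is defined elsewhere in the simulator and not shown in this diff.
	usageChan := make(chan float64, 8)

	// Consumer goroutine drains the channel, so the non-blocking send
	// below rarely hits its default branch.
	go func() {
		for usage := range usageChan {
			fmt.Printf("kv-cache usage: %.2f\n", usage)
		}
	}()

	// Producer side, mimicking the updates in startRequest/finishRequest.
	for used := 1; used <= 4; used++ {
		select {
		case usageChan <- float64(used) / 4.0:
		default: // buffer full: skip this update; a later one will catch up
		}
	}

	time.Sleep(50 * time.Millisecond) // give the consumer time to print
	close(usageChan)
}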

pkg/llm-d-inference-sim/lora.go

Lines changed: 2 additions & 1 deletion
@@ -20,6 +20,7 @@ package llmdinferencesim
 import (
 	"encoding/json"
 
+	"github.com/llm-d/llm-d-inference-sim/pkg/common"
 	"github.com/valyala/fasthttp"
 )
 
@@ -139,6 +140,6 @@ func (s *VllmSimulator) decrementLora(model string) {
 	s.loras.loadedLoras[model]--
 	if s.loras.loadedLoras[model] <= 0 {
 		// last usage of this LoRA
-		s.loras.loraRemovable <- 1
+		common.WriteToChannel(s.loras.loraRemovable, 1, s.logger, "loraRemovable")
 	}
 }

pkg/llm-d-inference-sim/simulator.go

Lines changed: 13 additions & 10 deletions
@@ -420,7 +420,7 @@ func (s *VllmSimulator) processing(ctx context.Context) {
 
 			s.logger.V(4).Info("Sending the request to the processing channel", "model", model,
 				"req id", reqCtx.CompletionReq.GetRequestID(), "worker", worker.id)
-			worker.reqChan <- reqCtx
+			common.WriteToChannel(worker.reqChan, reqCtx, s.logger, "worker's reqChan")
 		}
 	}
 }

@@ -431,9 +431,9 @@ func (s *VllmSimulator) findRequestAndSendToProcess(worker *worker) bool {
 		// send this request for processing in this worker
 		s.logger.V(4).Info("Sending request to processing", "model", nextReq.CompletionReq.GetModel(),
 			"req", nextReq.CompletionReq.GetRequestID(), "worker", worker.id)
-		worker.reqChan <- nextReq
+		common.WriteToChannel(worker.reqChan, nextReq, s.logger, "worker's reqChan")
 		// decrement waiting requests metric
-		s.metrics.waitingReqChan <- -1
+		common.WriteToChannel(s.metrics.waitingReqChan, -1, s.logger, "metrics.waitingReqChan")
 		return true
 	}
 
@@ -450,9 +450,11 @@ func (s *VllmSimulator) addRequestToQueue(reqCtx *openaiserverapi.CompletionReqC
 		return
 	}
 	// increment the waiting requests metric
-	s.metrics.waitingReqChan <- 1
+	common.WriteToChannel(s.metrics.waitingReqChan, 1, s.logger, "metrics.waitingReqChan")
 	// update loraInfo metrics with the new waiting request
-	s.metrics.lorasChan <- loraUsage{reqCtx.CompletionReq.GetModel(), waitingUsageState}
+	common.WriteToChannel(s.metrics.lorasChan, loraUsage{reqCtx.CompletionReq.GetModel(), waitingUsageState},
+		s.logger, "metrics.lorasChan")
+
 }
 
 // handleCompletions general completion requests handler, support both text and chat completion APIs

@@ -487,18 +489,19 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple
 		IsChatCompletion: isChatCompletion,
 		Wg:               &wg,
 	}
-	s.newRequests <- reqCtx
+	common.WriteToChannel(s.newRequests, reqCtx, s.logger, "newRequests")
 	wg.Wait()
 }
 
 // request processing finished
 func (s *VllmSimulator) responseSentCallback(model string, isChatCompletion bool, requestID string) {
 	// decrement running requests count
-	s.metrics.runReqChan <- -1
+	common.WriteToChannel(s.metrics.runReqChan, -1, s.logger, "metrics.runReqChan")
 
 	if s.isLora(model) {
 		// update loraInfo metrics to reflect that the request processing has been finished
-		s.metrics.lorasChan <- loraUsage{model, doneUsageState}
+		common.WriteToChannel(s.metrics.lorasChan, loraUsage{model, doneUsageState},
+			s.logger, "metrics.lorasChan")
 	}
 
 	if s.config.EnableKVCache && !isChatCompletion {

@@ -580,14 +583,14 @@ func (s *VllmSimulator) sendResponse(reqCtx *openaiserverapi.CompletionReqCtx, r
 	time.Sleep(time.Duration(ttft) * time.Millisecond)
 
 	// report ttft in seconds
-	s.metrics.ttftChan <- (float64(ttft) / 1000)
+	common.WriteToChannel(s.metrics.ttftChan, (float64(ttft) / 1000), s.logger, "metrics.ttftChan")
 
 	for range usageData.CompletionTokens - 1 {
 		perTokenLatency := s.getInterTokenLatency()
 		time.Sleep(time.Duration(perTokenLatency) * time.Millisecond)
 
 		// report tpot in seconds
-		s.metrics.tpotChan <- float64(perTokenLatency) / 1000
+		common.WriteToChannel(s.metrics.tpotChan, (float64(perTokenLatency) / 1000), s.logger, "metrics.tpotChan")
 	}
 	s.sendCompletionResponse(reqCtx.HTTPReqCtx, resp)
 

pkg/llm-d-inference-sim/streaming.go

Lines changed: 2 additions & 2 deletions
@@ -104,14 +104,14 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ
 	ttft := s.getWaitTimeToFirstToken(context.nPromptTokens, context.nCachedPromptTokens, context.doRemotePrefill)
 	time.Sleep(time.Duration(ttft) * time.Millisecond)
 	// report ttft in seconds
-	s.metrics.ttftChan <- (float64(ttft) / 1000)
+	common.WriteToChannel(s.metrics.ttftChan, (float64(ttft) / 1000), s.logger, "metrics.ttftChan")
 
 	for i, token := range genTokens {
 		if i != 0 {
 			interTokenLat := s.getInterTokenLatency()
 			time.Sleep(time.Duration(interTokenLat) * time.Millisecond)
 			// report tpot in seconds
-			s.metrics.tpotChan <- float64(interTokenLat) / 1000
+			common.WriteToChannel(s.metrics.tpotChan, (float64(interTokenLat) / 1000), s.logger, "metrics.tpotChan")
 		}
 
 		var toolChunkInsert *openaiserverapi.ToolCall

pkg/llm-d-inference-sim/worker.go

Lines changed: 11 additions & 12 deletions
@@ -21,6 +21,7 @@ import (
 	"context"
 
 	"github.com/go-logr/logr"
+	"github.com/llm-d/llm-d-inference-sim/pkg/common"
 	"github.com/llm-d/llm-d-inference-sim/pkg/dataset"
 	openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
 	"github.com/valyala/fasthttp"

@@ -63,12 +64,13 @@ func (s *VllmSimulator) processRequest(reqCtx *openaiserverapi.CompletionReqCtx)
 	displayModel := s.getDisplayedModelName(model)
 
 	// increment running requests count
-	s.metrics.runReqChan <- 1
+	common.WriteToChannel(s.metrics.runReqChan, 1, s.logger, "metrics.runReqChan")
 
 	if s.isLora(model) {
 		// update loraInfo metric to reflect that
 		// the request has changed its status from waiting to running
-		s.metrics.lorasChan <- loraUsage{model, runningUsageState}
+		common.WriteToChannel(s.metrics.lorasChan, loraUsage{model, runningUsageState}, s.logger,
+			"metrics.lorasChan")
 	}
 
 	if s.config.EnableKVCache && !reqCtx.IsChatCompletion {

@@ -137,16 +139,13 @@ func (s *VllmSimulator) processRequest(reqCtx *openaiserverapi.CompletionReqCtx)
 			s.sendResponse(reqCtx, responseTokens, toolCalls, displayModel, finishReason, &usageData)
 		}
 
-		select {
-		case s.metrics.requestSuccessChan <- requestSuccessEvent{
-			promptTokens:     usageData.PromptTokens,
-			generationTokens: usageData.CompletionTokens,
-			maxTokens:        reqCtx.CompletionReq.GetMaxCompletionTokens(),
-			finishReason:     finishReason,
-		}:
-		default:
-			s.logger.V(1).Info("requestSuccessChan full, dropping success event")
-		}
+		common.WriteToChannel(s.metrics.requestSuccessChan,
+			requestSuccessEvent{
+				promptTokens:     usageData.PromptTokens,
+				generationTokens: usageData.CompletionTokens,
+				maxTokens:        reqCtx.CompletionReq.GetMaxCompletionTokens(),
+				finishReason:     finishReason},
+			s.logger, "metrics.requestSuccessChan")
 	}
 
 	s.logger.V(4).Info("Finished processing request", "id", req.GetRequestID())
