
Commit 658e3e5

Add synchronization of freeing worker after stream request processing (#244)

* add synchronization of freeing worker after stream request processing
* additional changes which fix e2e request latency and inference time calculations for requests in streaming mode

Signed-off-by: Maya Barnea <[email protected]>
1 parent 3967e23 commit 658e3e5
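
For streaming requests the HTTP handler returns before the response chunks are written, so measuring end-to-end latency with a defer inside the handler (the previous approach) stops the clock too early. The commit instead stamps the request with StartProcessing on arrival and reports the latency only after the streaming side signals completion. Below is a minimal, hypothetical sketch of that timing pattern; none of these names are the simulator's real types.

package main

import (
	"fmt"
	"sync"
	"time"
)

// request carries its own arrival timestamp, analogous to CompletionReqCtx.StartProcessing.
type request struct {
	startProcessing time.Time
	done            sync.WaitGroup
}

func main() {
	req := &request{startProcessing: time.Now()}
	req.done.Add(1)

	// The "streaming" goroutine keeps sending chunks after the handler would
	// already have returned; it signals completion explicitly.
	go func() {
		defer req.done.Done()
		time.Sleep(30 * time.Millisecond) // stand-in for sending chunks
	}()

	// e2e latency is recorded only once the last chunk is out,
	// not when the handler exits.
	req.done.Wait()
	fmt.Println("e2e latency:", time.Since(req.startProcessing))
}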

File tree

5 files changed: +62 -19 lines changed

pkg/llm-d-inference-sim/simulator.go

Lines changed: 2 additions & 5 deletions

@@ -491,12 +491,8 @@ func (s *VllmSimulator) addRequestToQueue(reqCtx *openaiserverapi.CompletionReqC
 }
 
 // handleCompletions general completion requests handler, support both text and chat completion APIs
+// Important note: for requests in streaming mode, this function exits before all chunks are sent to the client
 func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatCompletion bool) {
-	startTime := time.Now()
-	defer func() {
-		common.WriteToChannel(s.metrics.e2eReqLatencyChan, time.Since(startTime).Seconds(), s.logger, "metrics.e2eReqLatencyChan")
-	}()
-
 	// Check if we should inject a failure
 	if shouldInjectFailure(s.config, s.random) {
 		failure := getRandomFailure(s.config, s.random)
@@ -526,6 +522,7 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple
 		HTTPReqCtx:       ctx,
 		IsChatCompletion: isChatCompletion,
 		Wg:               &wg,
+		StartProcessing:  time.Now(),
 	}
 	common.WriteToChannel(s.newRequests, reqCtx, s.logger, "newRequests")
 	wg.Wait()

pkg/llm-d-inference-sim/streaming.go

Lines changed: 5 additions & 2 deletions

@@ -21,6 +21,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"strconv"
+	"sync"
 	"time"
 
 	"github.com/llm-d/llm-d-inference-sim/pkg/common"
@@ -47,7 +48,7 @@ type streamingContext struct {
 // response content is wrapped according SSE format
 // First token is send after timeToFirstToken milliseconds, every other token is sent after interTokenLatency milliseconds
 func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, responseTokens []string, toolCalls []openaiserverapi.ToolCall,
-	finishReason string, usageData *openaiserverapi.Usage) {
+	finishReason string, usageData *openaiserverapi.Usage, wg *sync.WaitGroup) {
 	context.ctx.SetContentType("text/event-stream")
 	context.ctx.SetStatusCode(fasthttp.StatusOK)
 
@@ -78,8 +79,9 @@ func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, respons
 			s.sendTokenChunks(context, w, tc.Function.TokenizedArguments, &tc, finishReason)
 		}
 	} else {
-		s.logger.Info("Going to send text", "number of tokens", len(responseTokens))
+		s.logger.V(4).Info("Going to send text", "number of tokens", len(responseTokens))
 		s.sendTokenChunks(context, w, responseTokens, nil, finishReason)
+		s.logger.V(4).Info("Finished sending text", "number of tokens", len(responseTokens))
 	}
 }
 
@@ -98,6 +100,7 @@ func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, respons
 			return
 		}
 		s.responseSentCallback(context.model, context.isChatCompletion, context.requestID)
+		wg.Done()
 	})
 }
 
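
sendStreamingResponse writes its chunks from inside fasthttp's SetBodyStreamWriter callback, which only runs while the response body is being written to the connection, i.e. after the HTTP handler has returned. The WaitGroup threaded in here is therefore the signal that the stream has actually finished. A small, hypothetical sketch of that pattern outside the simulator (handler and chunk contents are made up):

package main

import (
	"bufio"
	"fmt"
	"sync"

	"github.com/valyala/fasthttp"
)

func streamHandler(ctx *fasthttp.RequestCtx) {
	var wg sync.WaitGroup
	wg.Add(1)

	ctx.SetContentType("text/event-stream")
	ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
		// This callback runs after streamHandler has returned.
		defer wg.Done()
		for i := 0; i < 3; i++ {
			fmt.Fprintf(w, "data: chunk %d\n\n", i)
			_ = w.Flush()
		}
	})

	// Anything that must happen only after the last chunk (freeing a worker,
	// recording latency) waits on wg from another goroutine; waiting here
	// would deadlock because the writer has not started yet.
	go func() {
		wg.Wait()
		fmt.Println("response fully sent")
	}()
}

func main() {
	_ = fasthttp.ListenAndServe(":8080", streamHandler)
}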

pkg/llm-d-inference-sim/test_utils.go

Lines changed: 34 additions & 4 deletions

@@ -151,11 +151,25 @@ func startServerAndSendRequest(modelName string, prompt string, isStreaming bool
 	client, err := startServerWithArgs(ctx, args)
 	gomega.Expect(err).NotTo(gomega.HaveOccurred())
 
-	openaiclient, params := getOpenAIClientAndChatParams(client, modelName, prompt, isStreaming)
+	openaitextclient, params := getOpenAIClientAndTextParams(client, modelName, prompt, isStreaming)
 
-	// send a single request in a serial way
-	_, err = openaiclient.Chat.Completions.New(ctx, params)
-	gomega.Expect(err).NotTo(gomega.HaveOccurred())
+	if isStreaming {
+		// send a single request in a serial way
+		stream := openaitextclient.Completions.NewStreaming(ctx, params)
+		chunksCnt := 0
+
+		for stream.Next() {
+			chunksCnt++
+		}
+		if err := stream.Err(); err != nil {
+			gomega.Expect(err).NotTo(gomega.HaveOccurred())
+		}
+		// number of chunks is number of tokens + 2 (one chunk with usage info and one closing chunk)
+		gomega.Expect(chunksCnt).To(gomega.BeNumerically("==", len(common.Tokenize(prompt))+2))
+	} else {
+		_, err = openaitextclient.Completions.New(ctx, params)
+		gomega.Expect(err).NotTo(gomega.HaveOccurred())
+	}
 
 	return client
 }
@@ -198,6 +212,22 @@ func getOpenAIClientAndChatParams(client option.HTTPClient, model string, messag
 	return openaiclient, params
 }
 
+// getOpenAIClientAndTextParams - creates an openai client and params for /completions call based on the given parameters
+func getOpenAIClientAndTextParams(client option.HTTPClient, model string, message string, streaming bool) (openai.Client, openai.CompletionNewParams) {
+	openaiclient := openai.NewClient(
+		option.WithBaseURL(baseURL),
+		option.WithHTTPClient(client))
+
+	params := openai.CompletionNewParams{
+		Prompt: openai.CompletionNewParamsPromptUnion{OfString: param.Opt[string]{Value: message}},
+		Model:  openai.CompletionNewParamsModel(model),
+	}
+	if streaming {
+		params.StreamOptions = openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)}
+	}
+	return openaiclient, params
+}
+
 // nolint
 // getOpenAIClentAndCompletionParams - creates an openai client and params for /completions call based on the given parameters
 func getOpenAIClentAndCompletionParams(client option.HTTPClient, model string, message string,
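
The updated test drives the simulator through the text-completions streaming path with the openai-go client and checks that the number of SSE chunks equals the number of prompt tokens plus two (one usage chunk and one closing chunk). The following is a hedged sketch of the same consumption pattern in isolation, assuming the openai and context imports already used in this file; the helper name and the usage lookup are illustrative, not part of the test:

// countChunks is a hypothetical helper mirroring the check in the test above:
// it consumes a streaming /completions response and returns the chunk count
// and the token total reported in the usage chunk.
func countChunks(ctx context.Context, client openai.Client, params openai.CompletionNewParams) (int, int64, error) {
	stream := client.Completions.NewStreaming(ctx, params)
	defer stream.Close()

	chunks := 0
	var totalTokens int64
	for stream.Next() {
		chunks++
		if u := stream.Current().Usage; u.TotalTokens > 0 {
			totalTokens = u.TotalTokens // the usage chunk, present because IncludeUsage is set
		}
	}
	return chunks, totalTokens, stream.Err()
}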

pkg/llm-d-inference-sim/worker.go

Lines changed: 19 additions & 8 deletions

@@ -19,6 +19,7 @@ package llmdinferencesim
 
 import (
 	"context"
+	"sync"
 	"time"
 
 	"github.com/go-logr/logr"
@@ -49,22 +50,31 @@ func (w *worker) waitForRequests() {
 			w.logger.V(4).Info("worker done", "id", w.id)
 			return
 		case req := <-w.reqChan:
-			w.processor.processRequest(req)
+			w.processor.processRequest(req, nil)
 			w.finishedChan <- &requestCompleted{worker: w, model: req.CompletionReq.GetModel()}
 		}
+
 	}
 }
 
 type requestProcessor interface {
-	processRequest(reqCtx *openaiserverapi.CompletionReqCtx)
+	processRequest(reqCtx *openaiserverapi.CompletionReqCtx, wg *sync.WaitGroup)
 }
 
-func (s *VllmSimulator) processRequest(reqCtx *openaiserverapi.CompletionReqCtx) {
-	start := time.Now()
-	defer func() {
-		common.WriteToChannel(s.metrics.reqInferenceTimeChan, time.Since(start).Seconds(), s.logger, "metrics.reqInferenceTimeChan")
-	}()
+func (s *VllmSimulator) processRequest(reqCtx *openaiserverapi.CompletionReqCtx, _ *sync.WaitGroup) {
+	startTime := time.Now()
+	wg := sync.WaitGroup{}
+	wg.Add(1)
+
+	go s.processRequestAsync(reqCtx, &wg)
+
+	wg.Wait()
+	// calculate inference time and finish e2e latency calculation only when sure that request processing was finished for streaming requests too
+	common.WriteToChannel(s.metrics.e2eReqLatencyChan, time.Since(reqCtx.StartProcessing).Seconds(), s.logger, "metrics.e2eReqLatencyChan")
+	common.WriteToChannel(s.metrics.reqInferenceTimeChan, time.Since(startTime).Seconds(), s.logger, "metrics.reqInferenceTimeChan")
+}
 
+func (s *VllmSimulator) processRequestAsync(reqCtx *openaiserverapi.CompletionReqCtx, wg *sync.WaitGroup) {
 	req := reqCtx.CompletionReq
 	model := req.GetModel()
 	displayModel := s.getDisplayedModelName(model)
@@ -138,14 +148,15 @@ func (s *VllmSimulator) processRequest(reqCtx *openaiserverapi.CompletionReqCtx)
 			// Logprobs configuration
 			logprobs: req.GetLogprobs(),
 		},
-			responseTokens, toolCalls, finishReason, usageDataToSend,
+			responseTokens, toolCalls, finishReason, usageDataToSend, wg,
 		)
 	} else {
 		if req.IsDoRemoteDecode() {
 			// in case this is prefill pod processing, return special finish reason
 			finishReason = dataset.RemoteDecodeFinishReason
 		}
		s.sendResponse(reqCtx, responseTokens, toolCalls, displayModel, finishReason, &usageData)
+		wg.Done()
 	}
 
 	common.WriteToChannel(s.metrics.requestSuccessChan,
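
The worker reports itself back on finishedChan as soon as processRequest returns, so before this change a streaming request would free its worker while the previous response was still being written. Splitting the work into processRequestAsync and having processRequest wait on the WaitGroup closes that gap, and also lets both latency metrics be recorded after the response is truly complete. A stripped-down, hypothetical model of the worker loop and the race it avoids (standard library only, not the simulator's types):

package main

import (
	"fmt"
	"sync"
	"time"
)

type job struct{ id int }

// worker frees itself only after process returns; if process returned before
// streaming finished, the worker would be handed a new request too early.
func worker(reqChan <-chan job, freeChan chan<- struct{}) {
	for j := range reqChan {
		process(j)
		freeChan <- struct{}{}
	}
}

func process(j job) {
	start := time.Now()
	var wg sync.WaitGroup
	wg.Add(1)

	// stand-in for processRequestAsync: generate and stream the response
	go func() {
		defer wg.Done()
		time.Sleep(20 * time.Millisecond)
	}()

	wg.Wait() // block until the last chunk is sent
	fmt.Printf("job %d done, inference time %v\n", j.id, time.Since(start))
}

func main() {
	reqChan := make(chan job)
	freeChan := make(chan struct{})
	go worker(reqChan, freeChan)

	reqChan <- job{id: 1}
	<-freeChan // the worker is only freed after job 1 fully finished
	close(reqChan)
}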

pkg/openai-server-api/request.go

Lines changed: 2 additions & 0 deletions

@@ -19,6 +19,7 @@ package openaiserverapi
 
 import (
 	"sync"
+	"time"
 
 	"github.com/valyala/fasthttp"
 )
@@ -163,6 +164,7 @@ type CompletionReqCtx struct {
 	HTTPReqCtx       *fasthttp.RequestCtx
 	IsChatCompletion bool
 	Wg               *sync.WaitGroup
+	StartProcessing  time.Time
 }
 
 // ChatCompletionRequest defines structure of /chat/completion request
