Skip to content

Commit 574342e

Browse files
committed
Move token generation to simulator
Signed-off-by: Qifan Deng <[email protected]>
1 parent b766507 commit 574342e

File tree

3 files changed

+38
-42
lines changed

3 files changed

+38
-42
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,4 +8,4 @@ vendor
88
*.test
99
manifests/dev-config.yaml
1010
pkg/llm-d-inference-sim/.llm-d
11-
.llm-d/
11+
pkg/llm-d-inference-sim/tests-tmp/

pkg/llm-d-inference-sim/simulator.go

Lines changed: 36 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -339,7 +339,7 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
339339
if toolCalls == nil && err == nil {
340340
// Either no tool calls were defined, or we randomly chose not to create tool calls,
341341
// so we generate a response text.
342-
responseTokens, finishReason, completionTokens, err = req.CreateResponseText(s.config.Mode)
342+
responseTokens, finishReason, completionTokens, err = s.generateTokens(req)
343343
}
344344
if err != nil {
345345
prefix := ""
@@ -514,8 +514,6 @@ func (s *VllmSimulator) createModelsResponse() *vllmapi.ModelsResponse {
514514

515515
return &modelsResp
516516
}
517-
<<<<<<< HEAD
518-
=======
519517

520518
// HandleHealth http handler for /health
521519
func (s *VllmSimulator) HandleHealth(ctx *fasthttp.RequestCtx) {
@@ -599,4 +597,38 @@ func (s *VllmSimulator) GetPrefillTimePerToken() int {
599597
func (s *VllmSimulator) GetInterTokenLatency() int {
600598
return int(float64(s.config.InterTokenLatency) * s.getCurrFactor())
601599
}
602-
>>>>>>> 482434e (Show config in yaml)
600+
601+
// generateTokens creates and returns response payload based on this request,
602+
// i.e., an array of generated tokens, the finish reason, and the number of created tokens
603+
func (s *VllmSimulator) generateTokens(req openaiserverapi.CompletionRequest) ([]string, string, int, error) {
604+
// if req is ChatCompletionRequest
605+
ignoreEOS := req.GetIgnoreEOS()
606+
var maxTokens *int64
607+
var prompt string
608+
609+
if chatReq, ok := req.(*openaiserverapi.ChatCompletionRequest); ok {
610+
maxTokens = chatReq.GetMaxCompletionTokens()
611+
prompt = chatReq.GetLastUserMsg()
612+
} else if textReq, ok := req.(*openaiserverapi.TextCompletionRequest); ok {
613+
maxTokens = textReq.MaxTokens
614+
prompt = textReq.GetPrompt()
615+
} else {
616+
return nil, "", 0, fmt.Errorf("unknown request type: %T", req)
617+
}
618+
619+
maxTokensValue, err := common.GetMaxTokens(nil, maxTokens)
620+
if err != nil {
621+
return nil, "", 0, err
622+
}
623+
624+
var text, finishReason string
625+
if s.config.Mode == common.ModeEcho {
626+
text, finishReason = common.GetResponseText(maxTokensValue, prompt)
627+
} else {
628+
text, finishReason = common.GetRandomResponseText(maxTokensValue, ignoreEOS)
629+
}
630+
631+
tokens := common.Tokenize(text)
632+
return tokens, finishReason, len(tokens), nil
633+
}
634+
>>>>>>> 48ec8bc (Move token generation to simulator)

pkg/openai-server-api/request.go

Lines changed: 1 addition & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -33,10 +33,6 @@ const (
3333
type CompletionRequest interface {
3434
// GetRequestID returns the unique request id
3535
GetRequestID() string
36-
// CreateResponseText creates and returns response payload based on this request,
37-
// i.e., an array of generated tokens, the finish reason, and the number of created
38-
// tokens
39-
CreateResponseText(mode string) ([]string, string, int, error)
4036
// IsStream returns boolean that defines is response should be streamed
4137
IsStream() bool
4238
// GetModel returns model name as defined in the request
@@ -230,7 +226,7 @@ func (c *ChatCompletionRequest) GetMaxCompletionTokens() *int64 {
230226

231227
// getLastUserMsg returns last message from this request's messages with user role,
232228
// if does not exist - returns an empty string
233-
func (req *ChatCompletionRequest) getLastUserMsg() string {
229+
func (req *ChatCompletionRequest) GetLastUserMsg() string {
234230
for i := len(req.Messages) - 1; i >= 0; i-- {
235231
if req.Messages[i].Role == RoleUser {
236232
return req.Messages[i].Content.PlainText()
@@ -240,31 +236,6 @@ func (req *ChatCompletionRequest) getLastUserMsg() string {
240236
return ""
241237
}
242238

243-
// CreateResponseText creates and returns response payload based on this request,
244-
// i.e., an array of generated tokens, the finish reason, and the number of created
245-
// tokens
246-
func (req ChatCompletionRequest) CreateResponseText(mode string) ([]string, string, int, error) {
247-
return generateResponseText(mode, req.GetMaxCompletionTokens(), req.getLastUserMsg(), req.GetIgnoreEOS())
248-
}
249-
250-
// Helper function to generate response text
251-
func generateResponseText(mode string, maxTokens *int64, prompt string, ignoreEOS bool) ([]string, string, int, error) {
252-
maxTokensValue, err := common.GetMaxTokens(nil, maxTokens)
253-
if err != nil {
254-
return nil, "", 0, err
255-
}
256-
257-
var text, finishReason string
258-
if mode == common.ModeEcho {
259-
text, finishReason = common.GetResponseText(maxTokensValue, prompt)
260-
} else {
261-
text, finishReason = common.GetRandomResponseText(maxTokensValue, ignoreEOS)
262-
}
263-
264-
tokens := common.Tokenize(text)
265-
return tokens, finishReason, len(tokens), nil
266-
}
267-
268239
// v1/completion
269240
// TextCompletionRequest defines structure of /completion request
270241
type TextCompletionRequest struct {
@@ -299,10 +270,3 @@ func (c *TextCompletionRequest) GetToolChoice() string {
299270
func (c *TextCompletionRequest) GetMaxCompletionTokens() *int64 {
300271
return c.MaxTokens
301272
}
302-
303-
// CreateResponseText creates and returns response payload based on this request,
304-
// i.e., an array of generated tokens, the finish reason, and the number of created
305-
// tokens
306-
func (req TextCompletionRequest) CreateResponseText(mode string) ([]string, string, int, error) {
307-
return generateResponseText(mode, req.MaxTokens, req.Prompt, req.GetIgnoreEOS())
308-
}

0 commit comments

Comments (0)