Commit 2b4a79a
Support long responses and additional fixes (#104)
* - Use the tokenize function (which splits text on spaces and additional separator characters) in request processing too, not only in the tools-related part
  - Validate max_tokens and max_completion_tokens when a request arrives
  - Protect generation of any random value with a mutex
  - Fix tests for the changes above, and add a test for random text creation

  Signed-off-by: Maya Barnea <[email protected]>

* Fix lint problem according to the PR comment

  Signed-off-by: Maya Barnea <[email protected]>

* Restore tests that check the validity of the returned response text; verify that it could be built from the predefined parts

  Signed-off-by: Maya Barnea <[email protected]>

* Fixed typo in comment

  Signed-off-by: Maya Barnea <[email protected]>

---------

Signed-off-by: Maya Barnea <[email protected]>
1 parent 7f1f766 commit 2b4a79a
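Note: the mutex protection for random-value generation described in the commit message lives in changed files that are not shown in this capture (only five of the seven diffs appear below). A minimal sketch of the idea, with hypothetical names (randMutex, randEngine, randomInt are assumptions, not the simulator's actual identifiers): a seeded math/rand.Rand is not safe for concurrent use, so every draw is serialized behind a lock.

	package main

	import (
		"fmt"
		"math/rand"
		"sync"
	)

	// Sketch only: serialize access to a shared, seeded random engine so that
	// concurrent request workers cannot interleave draws and break seeded
	// reproducibility. Names here are assumptions, not the simulator's own.
	var (
		randMutex  sync.Mutex
		randEngine = rand.New(rand.NewSource(42))
	)

	func randomInt(max int) int {
		randMutex.Lock()
		defer randMutex.Unlock()
		return randEngine.Intn(max)
	}

	func main() {
		fmt.Println(randomInt(10))
	}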

File tree

7 files changed: +249 −62 lines changed

pkg/llm-d-inference-sim/request.go

Lines changed: 2 additions & 3 deletions
@@ -18,7 +18,6 @@ limitations under the License.
 package llmdinferencesim
 
 import (
-	"strings"
 	"sync"
 
 	"github.com/valyala/fasthttp"
@@ -158,7 +157,7 @@ func (c *chatCompletionRequest) getNumberOfPromptTokens() int {
 	for _, message := range c.Messages {
 		messages += message.Content.PlainText() + " "
 	}
-	return len(strings.Fields(messages))
+	return len(tokenize(messages))
 }
 
 func (c *chatCompletionRequest) getTools() []tool {
@@ -224,7 +223,7 @@ type textCompletionRequest struct {
 }
 
 func (t *textCompletionRequest) getNumberOfPromptTokens() int {
-	return len(strings.Fields(t.Prompt))
+	return len(tokenize(t.Prompt))
 }
 
 func (c *textCompletionRequest) getTools() []tool {
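The tokenize helper that replaces strings.Fields here is defined in one of the changed files not shown on this page. Going by the commit message ("splits text on spaces and additional separator characters"), a hedged sketch of such a tokenizer might look like the following; tokenizeSketch and its separator set are assumptions, not the simulator's actual implementation.

	package main

	import (
		"fmt"
		"strings"
		"unicode"
	)

	// tokenizeSketch is a hypothetical stand-in for the simulator's tokenize:
	// unlike strings.Fields, it also treats punctuation as token boundaries.
	func tokenizeSketch(text string) []string {
		return strings.FieldsFunc(text, func(r rune) bool {
			return unicode.IsSpace(r) || strings.ContainsRune(".,;:!?", r)
		})
	}

	func main() {
		// "This is a test." -> ["This" "is" "a" "test"]
		fmt.Println(tokenizeSketch("This is a test."))
	}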

pkg/llm-d-inference-sim/seed_test.go

Lines changed: 2 additions & 1 deletion
@@ -42,7 +42,8 @@ var _ = Describe("Simulator with seed", func() {
 			Prompt: openai.CompletionNewParamsPromptUnion{
 				OfString: openai.String(userMessage),
 			},
-			Model: openai.CompletionNewParamsModel(model),
+			Model:     openai.CompletionNewParamsModel(model),
+			MaxTokens: openai.Int(10),
 		}
 
 		resp, err := openaiclient.Completions.New(ctx, params)

pkg/llm-d-inference-sim/simulator.go

Lines changed: 4 additions & 0 deletions
@@ -354,6 +354,10 @@ func (s *VllmSimulator) validateRequest(req completionRequest) (string, string,
 		return fmt.Sprintf("The model `%s` does not exist.", req.getModel()), "NotFoundError", fasthttp.StatusNotFound
 	}
 
+	if req.getMaxCompletionTokens() != nil && *req.getMaxCompletionTokens() <= 0 {
+		return "Max completion tokens and max tokens should be positive", "Invalid request", fasthttp.StatusBadRequest
+	}
+
 	if req.doRemoteDecode() && req.isStream() {
 		return "Prefill does not support streaming", "Invalid request", fasthttp.StatusBadRequest
 	}
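With this guard, a request carrying a non-positive max_tokens or max_completion_tokens is rejected up front with HTTP 400 and a single shared message, replacing the previous per-field vLLM-style error strings. A client-side sketch against a locally running simulator (the port and the exact endpoint URL are assumptions):

	package main

	import (
		"bytes"
		"fmt"
		"io"
		"net/http"
	)

	func main() {
		// max_tokens is non-positive, so validateRequest should answer 400 with
		// "Max completion tokens and max tokens should be positive".
		body := []byte(`{"model":"my_model","prompt":"This is a test.","max_tokens":-1}`)
		resp, err := http.Post("http://localhost:8000/v1/completions",
			"application/json", bytes.NewReader(body))
		if err != nil {
			panic(err)
		}
		defer resp.Body.Close()
		msg, _ := io.ReadAll(resp.Body)
		fmt.Println(resp.StatusCode, string(msg)) // expect 400 and the message above
	}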

pkg/llm-d-inference-sim/simulator_test.go

Lines changed: 35 additions & 24 deletions
@@ -38,6 +38,9 @@ import (
 const model = "my_model"
 const baseURL = "http://localhost/v1"
 const userMessage = "This is a test."
+const invalidMaxTokensErrMsg = "Max completion tokens and max tokens should be positive"
+
+var userMsgTokens int64
 
 func startServer(ctx context.Context, mode string) (*http.Client, error) {
 	return startServerWithArgs(ctx, mode, nil)
@@ -65,6 +68,10 @@ func startServerWithArgs(ctx context.Context, mode string, args []string) (*http
 		return nil, err
 	}
 
+	// calculate the number of tokens for the user message;
+	// must be called after parseCommandParamsAndLoadConfig since it initializes the random engine
+	userMsgTokens = int64(len(tokenize(userMessage)))
+
 	// run request processing workers
 	for i := 1; i <= s.config.MaxNumSeqs; i++ {
 		go s.reqProcessingWorker(ctx, i)
@@ -132,17 +139,19 @@ var _ = Describe("Simulator", func() {
 			}
 
 			Expect(numberOfChunksWithUsage).To(Equal(1))
-			Expect(chunk.Usage.PromptTokens).To(Equal(int64(4)))
+			Expect(chunk.Usage.PromptTokens).To(Equal(userMsgTokens))
 			Expect(chunk.Usage.CompletionTokens).To(BeNumerically(">", 0))
 			Expect(chunk.Usage.TotalTokens).To(Equal(chunk.Usage.PromptTokens + chunk.Usage.CompletionTokens))
 
 			msg := strings.Join(tokens, "")
-			expectedMsg := userMessage
 			if mode == modeRandom {
-				expectedMsg = getFullTextFromPartialString(msg)
+				// in random mode, ensure that the returned message could be an output of the random text generator
+				Expect(isValidText(msg)).To(BeTrue())
+			} else {
+				// in echo mode, check that the text is returned as-is
+				Expect(msg).Should(Equal(userMessage))
 			}
 			Expect(role).Should(Equal("assistant"))
-			Expect(msg).Should(Equal(expectedMsg))
 		},
 		func(mode string) string {
 			return "mode: " + mode
@@ -189,16 +198,18 @@ var _ = Describe("Simulator", func() {
 				Expect(string(chunk.Object)).To(Equal(textCompletionObject))
 			}
 			Expect(numberOfChunksWithUsage).To(Equal(1))
-			Expect(chunk.Usage.PromptTokens).To(Equal(int64(4)))
+			Expect(chunk.Usage.PromptTokens).To(Equal(userMsgTokens))
 			Expect(chunk.Usage.CompletionTokens).To(BeNumerically(">", 0))
 			Expect(chunk.Usage.TotalTokens).To(Equal(chunk.Usage.PromptTokens + chunk.Usage.CompletionTokens))
 
 			text := strings.Join(tokens, "")
-			expectedText := userMessage
 			if mode == modeRandom {
-				expectedText = getFullTextFromPartialString(text)
+				// in random mode, ensure that the returned text could be an output of the random text generator
+				Expect(isValidText(text)).To(BeTrue())
+			} else {
+				// in echo mode, check that the text is returned as-is
+				Expect(text).Should(Equal(userMessage))
 			}
-			Expect(text).Should(Equal(expectedText))
 		},
 		func(mode string) string {
 			return "mode: " + mode
@@ -224,18 +235,15 @@ var _ = Describe("Simulator", func() {
 				Model: model,
 			}
 			numTokens := 0
-			partialErrMsg := ""
 			// if maxTokens and maxCompletionTokens are passed,
 			// maxCompletionTokens is used
 			if maxTokens != 0 {
 				params.MaxTokens = param.NewOpt(int64(maxTokens))
 				numTokens = maxTokens
-				partialErrMsg = "max_tokens must be at least 1, got -1"
 			}
 			if maxCompletionTokens != 0 {
 				params.MaxCompletionTokens = param.NewOpt(int64(maxCompletionTokens))
 				numTokens = maxCompletionTokens
-				partialErrMsg = "max_completion_tokens must be at least 1, got -1"
 			}
 			resp, err := openaiclient.Chat.Completions.New(ctx, params)
 			if err != nil {
@@ -244,7 +252,7 @@ var _ = Describe("Simulator", func() {
 				if openaiError.StatusCode == 400 {
 					errMsg, err := io.ReadAll(openaiError.Response.Body)
 					Expect(err).NotTo(HaveOccurred())
-					if strings.Contains(string(errMsg), partialErrMsg) {
+					if strings.Contains(string(errMsg), invalidMaxTokensErrMsg) {
 						return
 					}
 				}
@@ -254,22 +262,24 @@ var _ = Describe("Simulator", func() {
 			Expect(resp.Choices).ShouldNot(BeEmpty())
 			Expect(string(resp.Object)).To(Equal(chatCompletionObject))
 
-			Expect(resp.Usage.PromptTokens).To(Equal(int64(4)))
+			Expect(resp.Usage.PromptTokens).To(Equal(userMsgTokens))
 			Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0))
 			Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens))
 
 			msg := resp.Choices[0].Message.Content
 			Expect(msg).ShouldNot(BeEmpty())
 
 			if numTokens > 0 {
-				tokens := strings.Fields(msg)
+				tokens := tokenize(msg)
 				Expect(int64(len(tokens))).Should(BeNumerically("<=", numTokens))
 			} else {
-				expectedMsg := userMessage
 				if mode == modeRandom {
-					expectedMsg = getFullTextFromPartialString(msg)
+					// in random mode, ensure that the returned message could be an output of the random text generator
+					Expect(isValidText(msg)).To(BeTrue())
+				} else {
+					// in echo mode, check that the text is returned as-is
+					Expect(msg).Should(Equal(userMessage))
 				}
-				Expect(msg).Should(Equal(expectedMsg))
 			}
 		},
 		func(mode string, maxTokens int, maxCompletionTokens int) string {
@@ -310,7 +320,6 @@ var _ = Describe("Simulator", func() {
 				Model: openai.CompletionNewParamsModel(model),
 			}
 			numTokens := 0
-			partialErrMsg := "max_tokens must be at least 1, got -1"
 			if maxTokens != 0 {
 				params.MaxTokens = param.NewOpt(int64(maxTokens))
 				numTokens = maxTokens
@@ -322,7 +331,7 @@ var _ = Describe("Simulator", func() {
 				if openaiError.StatusCode == 400 {
 					errMsg, err := io.ReadAll(openaiError.Response.Body)
 					Expect(err).NotTo(HaveOccurred())
-					if strings.Contains(string(errMsg), partialErrMsg) {
+					if strings.Contains(string(errMsg), invalidMaxTokensErrMsg) {
 						return
 					}
 				}
@@ -332,22 +341,24 @@ var _ = Describe("Simulator", func() {
 			Expect(resp.Choices).ShouldNot(BeEmpty())
 			Expect(string(resp.Object)).To(Equal(textCompletionObject))
 
-			Expect(resp.Usage.PromptTokens).To(Equal(int64(4)))
+			Expect(resp.Usage.PromptTokens).To(Equal(userMsgTokens))
 			Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0))
 			Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens))
 
 			text := resp.Choices[0].Text
 			Expect(text).ShouldNot(BeEmpty())
 
 			if numTokens != 0 {
-				tokens := strings.Fields(text)
+				tokens := tokenize(text)
 				Expect(int64(len(tokens))).Should(BeNumerically("<=", numTokens))
 			} else {
-				expectedText := userMessage
 				if mode == modeRandom {
-					expectedText = getFullTextFromPartialString(text)
+					// in random mode, ensure that the returned text could be an output of the random text generator
+					Expect(isValidText(text)).To(BeTrue())
+				} else {
+					// in echo mode, check that the text is returned as-is
+					Expect(text).Should(Equal(userMessage))
 				}
-				Expect(text).Should(Equal(expectedText))
 			}
 		},
 		func(mode string, maxTokens int) string {
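isValidText is defined in one of the changed files not shown here; per the commit message, it checks that a response "could be built from the predefined parts" of the random text generator. A hypothetical sketch of that kind of check (isValidTextSketch, the greedy prefix matching, and the parts list are all assumptions):

	package main

	import (
		"fmt"
		"strings"
	)

	// isValidTextSketch reports whether text can be assembled by concatenating
	// entries of parts, allowing the final entry to be cut off mid-part (as
	// happens when generation hits the max-tokens budget).
	func isValidTextSketch(text string, parts []string) bool {
		for text != "" {
			matched := false
			for _, p := range parts {
				if strings.HasPrefix(text, p) {
					text = strings.TrimPrefix(strings.TrimPrefix(text, p), " ")
					matched = true
					break
				}
				if strings.HasPrefix(p, text) {
					return true // truncated final part
				}
			}
			if !matched {
				return false
			}
		}
		return true
	}

	func main() {
		parts := []string{"Hello world.", "Testing, testing 1 2 3."}
		fmt.Println(isValidTextSketch("Hello world. Testing,", parts)) // true
	}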

pkg/llm-d-inference-sim/tools_test.go

Lines changed: 6 additions & 6 deletions
@@ -398,7 +398,7 @@ var _ = Describe("Simulator for request with tools", func() {
 			}
 
 			Expect(numberOfChunksWithUsage).To(Equal(1))
-			Expect(chunk.Usage.PromptTokens).To(Equal(int64(4)))
+			Expect(chunk.Usage.PromptTokens).To(Equal(userMsgTokens))
 			Expect(chunk.Usage.CompletionTokens).To(BeNumerically(">", 0))
 			Expect(chunk.Usage.TotalTokens).To(Equal(chunk.Usage.PromptTokens + chunk.Usage.CompletionTokens))
 
@@ -451,7 +451,7 @@ var _ = Describe("Simulator for request with tools", func() {
 			Expect(resp.Choices).ShouldNot(BeEmpty())
 			Expect(string(resp.Object)).To(Equal(chatCompletionObject))
 
-			Expect(resp.Usage.PromptTokens).To(Equal(int64(4)))
+			Expect(resp.Usage.PromptTokens).To(Equal(userMsgTokens))
 			Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0))
 			Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens))
 
@@ -543,7 +543,7 @@ var _ = Describe("Simulator for request with tools", func() {
 			Expect(resp.Choices).ShouldNot(BeEmpty())
 			Expect(string(resp.Object)).To(Equal(chatCompletionObject))
 
-			Expect(resp.Usage.PromptTokens).To(Equal(int64(4)))
+			Expect(resp.Usage.PromptTokens).To(Equal(userMsgTokens))
 			Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0))
 			Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens))
 
@@ -599,7 +599,7 @@ var _ = Describe("Simulator for request with tools", func() {
 			Expect(resp.Choices).ShouldNot(BeEmpty())
 			Expect(string(resp.Object)).To(Equal(chatCompletionObject))
 
-			Expect(resp.Usage.PromptTokens).To(Equal(int64(4)))
+			Expect(resp.Usage.PromptTokens).To(Equal(userMsgTokens))
 			Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0))
 			Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens))
 
@@ -685,7 +685,7 @@ var _ = Describe("Simulator for request with tools", func() {
 			Expect(resp.Choices).ShouldNot(BeEmpty())
 			Expect(string(resp.Object)).To(Equal(chatCompletionObject))
 
-			Expect(resp.Usage.PromptTokens).To(Equal(int64(4)))
+			Expect(resp.Usage.PromptTokens).To(Equal(userMsgTokens))
 			Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0))
 			Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens))
 
@@ -747,7 +747,7 @@ var _ = Describe("Simulator for request with tools", func() {
 			Expect(resp.Choices).ShouldNot(BeEmpty())
 			Expect(string(resp.Object)).To(Equal(chatCompletionObject))
 
-			Expect(resp.Usage.PromptTokens).To(Equal(int64(4)))
+			Expect(resp.Usage.PromptTokens).To(Equal(userMsgTokens))
 			Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0))
 			Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens))
 
