Skip to content

Commit c09b608

Browse files
committed
Generate tokens instead of strings
Signed-off-by: Qifan Deng <[email protected]>
1 parent 574342e commit c09b608

File tree

3 files changed

+34
-62
lines changed

3 files changed

+34
-62
lines changed

pkg/common/utils.go

Lines changed: 9 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ limitations under the License.
1717
package common
1818

1919
import (
20-
"fmt"
2120
"math"
2221
"math/rand"
2322
"regexp"
@@ -73,26 +72,6 @@ func init() {
7372
}
7473
}
7574

76-
// returns the max tokens or error if incorrect
77-
func GetMaxTokens(maxCompletionTokens *int64, maxTokens *int64) (*int64, error) {
78-
var typeToken string
79-
var tokens *int64
80-
// if both arguments are passed,
81-
// use maxCompletionTokens
82-
// as in the real vllm
83-
if maxCompletionTokens != nil {
84-
tokens = maxCompletionTokens
85-
typeToken = "max_completion_tokens"
86-
} else if maxTokens != nil {
87-
tokens = maxTokens
88-
typeToken = "max_tokens"
89-
}
90-
if tokens != nil && *tokens < 1 {
91-
return nil, fmt.Errorf("%s must be at least 1, got %d", typeToken, *tokens)
92-
}
93-
return tokens, nil
94-
}
95-
9675
// ValidateContextWindow checks if the request fits within the model's context window
9776
// Returns validation result, actual completion tokens, and total tokens
9877
func ValidateContextWindow(promptTokens int, maxCompletionTokens *int64, maxModelLen int) (bool, int64, int64) {
@@ -157,7 +136,7 @@ func GetRandomText(numOfTokens int) string {
157136
return strings.Join(allTokens, "")
158137
}
159138

160-
// GetRandomResponseText generates text to be returned in a response, and the finish reason (stop or length)
139+
// GetRandomTokens generates tokens to be returned in a response, and the finish reason (stop or length)
161140
// if maxCompletionTokens is defined
162141
// - currently, the number of generated words in the text will be equal to its value
163142
// - in the future - need to find statistics about the distribution of generated tokens and return fewer tokens in part of the requests
@@ -167,7 +146,7 @@ func GetRandomText(numOfTokens int) string {
167146
// - finish reason is stop
168147
// if ignore_eos is true - the response will be generated with exactly maxCompletionTokens tokens
169148
// - request was validated so that when ignore_eos is true, maxCompletionTokens must be defined
170-
func GetRandomResponseText(maxCompletionTokens *int64, ignore_eos bool) (string, string) {
149+
func GetRandomTokens(maxCompletionTokens *int64, ignore_eos bool) ([]string, string) {
171150
numOfTokens := 0
172151
finishReason := StopFinishReason
173152

@@ -189,8 +168,7 @@ func GetRandomResponseText(maxCompletionTokens *int64, ignore_eos bool) (string,
189168
}
190169
}
191170

192-
text := GetRandomText(numOfTokens)
193-
return text, finishReason
171+
return Tokenize(GetRandomText(numOfTokens)), finishReason
194172
}
195173

196174
// getResponseLengthByHistogram calculates the number of tokens to be returned in a response based on the max tokens value and the pre-defined buckets.
@@ -282,23 +260,20 @@ func calcBucketBoundaries(maxTokens int, bucketIndex int) (start int, end int) {
282260
return start, end
283261
}
284262

285-
// GetResponseText returns response text, from a given text
263+
// GetResponseTokens returns needed tokens, from a given text
286264
// considering max completion tokens if it is not nil, and a finish reason (stop or length)
287-
func GetResponseText(maxCompletionTokens *int64, text string) (string, string) {
265+
func GetResponseTokens(maxCompletionTokens *int64, text string) ([]string, string) {
266+
tokens := Tokenize(text)
288267
// no max completion tokens, return all tokens of the text
289268
if maxCompletionTokens == nil {
290-
return text, StopFinishReason
269+
return tokens, StopFinishReason
291270
}
292271

293-
// create tokens from text, splitting by spaces
294-
tokens := Tokenize(text)
295-
296-
// return entire text
297272
if *maxCompletionTokens >= int64(len(tokens)) {
298-
return text, StopFinishReason
273+
return tokens, StopFinishReason
299274
}
300275
// return truncated list of tokens
301-
return strings.Join(tokens[0:*maxCompletionTokens], " "), LengthFinishReason
276+
return tokens[0:*maxCompletionTokens], LengthFinishReason
302277
}
303278

304279
func RandomNumericString(length int) string {

pkg/common/utils_test.go

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package common
1818

1919
import (
2020
"fmt"
21+
"strings"
2122
"time"
2223

2324
. "github.com/onsi/ginkgo/v2"
@@ -29,16 +30,17 @@ var _ = Describe("Utils", Ordered, func() {
2930
InitRandom(time.Now().UnixNano())
3031
})
3132

32-
Context("GetRandomResponseText", func() {
33+
Context("GetRandomTokens", func() {
3334
It("should return complete text", func() {
34-
text, finishReason := GetRandomResponseText(nil, false)
35+
tokens, finishReason := GetRandomTokens(nil, false)
36+
text := strings.Join(tokens, "")
3537
Expect(IsValidText(text)).To(BeTrue())
3638
Expect(finishReason).Should(Equal(StopFinishReason))
3739
})
3840
It("should return short text", func() {
3941
maxCompletionTokens := int64(2)
40-
text, finishReason := GetRandomResponseText(&maxCompletionTokens, false)
41-
tokensCnt := int64(len(Tokenize(text)))
42+
tokens, finishReason := GetRandomTokens(&maxCompletionTokens, false)
43+
tokensCnt := int64(len(tokens))
4244
Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens))
4345
if tokensCnt == maxCompletionTokens {
4446
Expect(finishReason).To(Equal(LengthFinishReason))
@@ -50,9 +52,10 @@ var _ = Describe("Utils", Ordered, func() {
5052
It("should return long text", func() {
5153
// return required number of tokens although it is higher than ResponseLenMax
5254
maxCompletionTokens := int64(ResponseLenMax * 5)
53-
text, finishReason := GetRandomResponseText(&maxCompletionTokens, false)
54-
tokensCnt := int64(len(Tokenize(text)))
55+
tokens, finishReason := GetRandomTokens(&maxCompletionTokens, false)
56+
tokensCnt := int64(len(tokens))
5557
Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens))
58+
text := strings.Join(tokens, "")
5659
Expect(IsValidText(text)).To(BeTrue())
5760
if tokensCnt == maxCompletionTokens {
5861
Expect(finishReason).To(Equal(LengthFinishReason))
@@ -65,8 +68,8 @@ var _ = Describe("Utils", Ordered, func() {
6568
DescribeTable("should return exact num of tokens",
6669
func(maxCompletionTokens int) {
6770
n := int64(maxCompletionTokens)
68-
text, finishReason := GetRandomResponseText(&n, true)
69-
nGenTokens := int64(len(Tokenize(text)))
71+
tokens, finishReason := GetRandomTokens(&n, true)
72+
nGenTokens := int64(len(tokens))
7073
Expect(nGenTokens).Should(Equal(n))
7174
Expect(finishReason).To(Equal(LengthFinishReason))
7275
},
@@ -80,24 +83,25 @@ var _ = Describe("Utils", Ordered, func() {
8083
)
8184
})
8285

83-
Context("GetResponseText", func() {
86+
Context("GetResponseTokens", func() {
8487
theText := "Give a man a fish and you feed him for a day; teach a man to fish and you feed him for a lifetime"
88+
theTokens := Tokenize(theText)
8589

8690
It("should return the same text since max tokens is not defined", func() {
87-
text, finishReason := GetResponseText(nil, theText)
88-
Expect(text).Should(Equal(theText))
91+
tokens, finishReason := GetResponseTokens(nil, theText)
92+
Expect(tokens).Should(Equal(theTokens))
8993
Expect(finishReason).Should(Equal(StopFinishReason))
9094
})
9195
It("should return the same text since max tokens is higher than the text length", func() {
9296
maxCompletionTokens := int64(1000)
93-
text, finishReason := GetResponseText(&maxCompletionTokens, theText)
94-
Expect(text).Should(Equal(theText))
97+
tokens, finishReason := GetResponseTokens(&maxCompletionTokens, theText)
98+
Expect(tokens).Should(Equal(theTokens))
9599
Expect(finishReason).Should(Equal(StopFinishReason))
96100
})
97101
It("should return partial text", func() {
98102
maxCompletionTokens := int64(2)
99-
text, finishReason := GetResponseText(&maxCompletionTokens, theText)
100-
Expect(int64(len(Tokenize(text)))).Should(Equal(maxCompletionTokens))
103+
tokens, finishReason := GetResponseTokens(&maxCompletionTokens, theText)
104+
Expect(int64(len(tokens))).Should(Equal(maxCompletionTokens))
101105
Expect(finishReason).Should(Equal(LengthFinishReason))
102106
})
103107
})

pkg/llm-d-inference-sim/simulator.go

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -599,9 +599,8 @@ func (s *VllmSimulator) GetInterTokenLatency() int {
599599
}
600600

601601
// generateTokens creates and returns response payload based on this request,
602-
// i.e., an array of generated tokens, the finish reason, and the number of created tokens
602+
// i.e., an array of generated tokens, the finish reason, and the number of generated tokens
603603
func (s *VllmSimulator) generateTokens(req openaiserverapi.CompletionRequest) ([]string, string, int, error) {
604-
// if req is ChatCompletionRequest
605604
ignoreEOS := req.GetIgnoreEOS()
606605
var maxTokens *int64
607606
var prompt string
@@ -616,19 +615,13 @@ func (s *VllmSimulator) generateTokens(req openaiserverapi.CompletionRequest) ([
616615
return nil, "", 0, fmt.Errorf("unknown request type: %T", req)
617616
}
618617

619-
maxTokensValue, err := common.GetMaxTokens(nil, maxTokens)
620-
if err != nil {
621-
return nil, "", 0, err
622-
}
623-
624-
var text, finishReason string
618+
var finishReason string
619+
var tokens []string
625620
if s.config.Mode == common.ModeEcho {
626-
text, finishReason = common.GetResponseText(maxTokensValue, prompt)
627-
} else {
628-
text, finishReason = common.GetRandomResponseText(maxTokensValue, ignoreEOS)
621+
tokens, finishReason = common.GetResponseTokens(maxTokens, prompt)
622+
return tokens, finishReason, len(tokens), nil
629623
}
630-
631-
tokens := common.Tokenize(text)
624+
tokens, finishReason = common.GetRandomTokens(maxTokens, ignoreEOS)
632625
return tokens, finishReason, len(tokens), nil
633626
}
634627
>>>>>>> 48ec8bc (Move token generation to simulator)

0 commit comments

Comments
 (0)