Skip to content

Commit e442062

Browse files
authored
Add ignore eos in request (#187)
* Add ignore eos in request
  Signed-off-by: Qifan Deng <[email protected]>
* Respect ignore eos
  Signed-off-by: Qifan Deng <[email protected]>
* Improve comments and remove duplicated code
  Signed-off-by: Qifan Deng <[email protected]>
---------
Signed-off-by: Qifan Deng <[email protected]>
1 parent 639b40e commit e442062

File tree

4 files changed

+47
-10
lines changed

4 files changed

+47
-10
lines changed

pkg/common/utils.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,9 @@ func GetRandomText(numOfTokens int) string {
165165
// if maxCompletionTokens is nil
166166
// - the response text's length is randomly chosen from the range [1, responseLenMax] according additional parameters
167167
// - finish reason is stop
168-
func GetRandomResponseText(maxCompletionTokens *int64) (string, string) {
168+
// if ignore_eos is true - the response will be generated with exactly maxCompletionTokens tokens
169+
// - request was validated so that when ignore_eos is true, maxCompletionTokens must be defined
170+
func GetRandomResponseText(maxCompletionTokens *int64, ignore_eos bool) (string, string) {
169171
numOfTokens := 0
170172
finishReason := StopFinishReason
171173

@@ -174,11 +176,16 @@ func GetRandomResponseText(maxCompletionTokens *int64) (string, string) {
174176
numOfTokens = GetRandomResponseLen()
175177
} else {
176178
maxTokens := int(*maxCompletionTokens)
177-
// max tokens is defined - generate real length of the response based on it
178-
numOfTokens = getResponseLengthByHistogram(maxTokens)
179-
if numOfTokens == maxTokens {
180-
// if response should be create with maximum number of tokens - finish reason will be 'length'
179+
if ignore_eos {
180+
numOfTokens = maxTokens
181181
finishReason = LengthFinishReason
182+
} else {
183+
// max tokens is defined - generate real length of the response based on it
184+
numOfTokens = getResponseLengthByHistogram(maxTokens)
185+
if numOfTokens == maxTokens {
186+
// if response should be create with maximum number of tokens - finish reason will be 'length'
187+
finishReason = LengthFinishReason
188+
}
182189
}
183190
}
184191

pkg/common/utils_test.go

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ var _ = Describe("Utils", Ordered, func() {
3131

3232
Context("GetRandomResponseText", func() {
3333
It("should return complete text", func() {
34-
text, finishReason := GetRandomResponseText(nil)
34+
text, finishReason := GetRandomResponseText(nil, false)
3535
Expect(IsValidText(text)).To(BeTrue())
3636
Expect(finishReason).Should(Equal(StopFinishReason))
3737
})
3838
It("should return short text", func() {
3939
maxCompletionTokens := int64(2)
40-
text, finishReason := GetRandomResponseText(&maxCompletionTokens)
40+
text, finishReason := GetRandomResponseText(&maxCompletionTokens, false)
4141
tokensCnt := int64(len(Tokenize(text)))
4242
Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens))
4343
if tokensCnt == maxCompletionTokens {
@@ -50,7 +50,7 @@ var _ = Describe("Utils", Ordered, func() {
5050
It("should return long text", func() {
5151
// return required number of tokens although it is higher than ResponseLenMax
5252
maxCompletionTokens := int64(ResponseLenMax * 5)
53-
text, finishReason := GetRandomResponseText(&maxCompletionTokens)
53+
text, finishReason := GetRandomResponseText(&maxCompletionTokens, false)
5454
tokensCnt := int64(len(Tokenize(text)))
5555
Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens))
5656
Expect(IsValidText(text)).To(BeTrue())
@@ -61,6 +61,23 @@ var _ = Describe("Utils", Ordered, func() {
6161
Expect(finishReason).To(Equal(StopFinishReason))
6262
}
6363
})
64+
65+
DescribeTable("should return exact num of tokens",
66+
func(maxCompletionTokens int) {
67+
n := int64(maxCompletionTokens)
68+
text, finishReason := GetRandomResponseText(&n, true)
69+
nGenTokens := int64(len(Tokenize(text)))
70+
Expect(nGenTokens).Should(Equal(n))
71+
Expect(finishReason).To(Equal(LengthFinishReason))
72+
},
73+
func(maxCompletionTokens int) string {
74+
return fmt.Sprintf("maxCompletionTokens: %d", maxCompletionTokens)
75+
},
76+
Entry("1", 1),
77+
Entry("42", 42),
78+
Entry("99", 99),
79+
Entry("10000", 10000),
80+
)
6481
})
6582

6683
Context("GetResponseText", func() {

pkg/llm-d-inference-sim/simulator.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,10 @@ func (s *VllmSimulator) validateRequest(req openaiserverapi.CompletionRequest) (
326326
return "Prefill does not support streaming", fasthttp.StatusBadRequest
327327
}
328328

329+
if req.GetIgnoreEOS() && req.GetMaxCompletionTokens() == nil {
330+
return "Ignore_eos is true but max_completion_tokens (or max_tokens) is not set", fasthttp.StatusBadRequest
331+
}
332+
329333
return "", fasthttp.StatusOK
330334
}
331335

pkg/openai-server-api/request.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ type CompletionRequest interface {
5959
GetToolChoice() string
6060
// GetMaxCompletionTokens returns the maximum completion tokens requested
6161
GetMaxCompletionTokens() *int64
62+
// GetIgnoreEOS returns true if the end-of-sequence tokens will be ignored
63+
GetIgnoreEOS() bool
6264
// IsDoRemoteDecode() returns true if do_remote_decode field is true in the request,
6365
// when the field is true, the decode phase should be done on remote pod,
6466
// whereas prefill phase is done on local pod, thus this is a prefill request
@@ -93,6 +95,8 @@ type baseCompletionRequest struct {
9395
RemotePort int `json:"remote_port"`
9496
// The number of tokens in the prompt that are in the local KV Cache
9597
cachedPromptTokens int
98+
// IgnoreEOS is a boolean value, true when the model should ignore end-of-sequence tokens
99+
IgnoreEOS bool `json:"ignore_eos"`
96100
}
97101

98102
// StreamOptions defines streaming options for streaming requests
@@ -131,6 +135,11 @@ func (b *baseCompletionRequest) GetNumberOfCachedPromptTokens() int {
131135
return b.cachedPromptTokens
132136
}
133137

138+
// GetIgnoreEOS returns the value of IgnoreEOS
139+
func (b *baseCompletionRequest) GetIgnoreEOS() bool {
140+
return b.IgnoreEOS
141+
}
142+
134143
// SetNumberOfCachedPromptTokens sets the number of tokens in the prompt that are
135144
// in the local KV Cache
136145
func (b *baseCompletionRequest) SetNumberOfCachedPromptTokens(cachedPromptTokens int) {
@@ -244,7 +253,7 @@ func (req ChatCompletionRequest) CreateResponseText(mode string) ([]string, stri
244253
if mode == common.ModeEcho {
245254
text, finishReason = common.GetResponseText(maxTokens, req.getLastUserMsg())
246255
} else {
247-
text, finishReason = common.GetRandomResponseText(maxTokens)
256+
text, finishReason = common.GetRandomResponseText(maxTokens, req.GetIgnoreEOS())
248257
}
249258

250259
tokens := common.Tokenize(text)
@@ -299,7 +308,7 @@ func (req TextCompletionRequest) CreateResponseText(mode string) ([]string, stri
299308
if mode == common.ModeEcho {
300309
text, finishReason = common.GetResponseText(maxTokens, req.Prompt)
301310
} else {
302-
text, finishReason = common.GetRandomResponseText(maxTokens)
311+
text, finishReason = common.GetRandomResponseText(maxTokens, req.GetIgnoreEOS())
303312
}
304313

305314
tokens := common.Tokenize(text)

0 commit comments

Comments (0)