diff --git a/pkg/common/utils.go b/pkg/common/utils.go index a04692dc..31f5f288 100644 --- a/pkg/common/utils.go +++ b/pkg/common/utils.go @@ -165,7 +165,9 @@ func GetRandomText(numOfTokens int) string { // if maxCompletionTokens is nil // - the response text's length is randomly chosen from the range [1, responseLenMax] according additional parameters // - finish reason is stop -func GetRandomResponseText(maxCompletionTokens *int64) (string, string) { +// if ignore_eos is true - the response will be generated with exactly maxCompletionTokens tokens +// - request was validated so that when ignore_eos is true, maxCompletionTokens must be defined +func GetRandomResponseText(maxCompletionTokens *int64, ignore_eos bool) (string, string) { numOfTokens := 0 finishReason := StopFinishReason @@ -174,11 +176,16 @@ func GetRandomResponseText(maxCompletionTokens *int64) (string, string) { numOfTokens = GetRandomResponseLen() } else { maxTokens := int(*maxCompletionTokens) - // max tokens is defined - generate real length of the response based on it - numOfTokens = getResponseLengthByHistogram(maxTokens) - if numOfTokens == maxTokens { - // if response should be create with maximum number of tokens - finish reason will be 'length' + if ignore_eos { + numOfTokens = maxTokens finishReason = LengthFinishReason + } else { + // max tokens is defined - generate real length of the response based on it + numOfTokens = getResponseLengthByHistogram(maxTokens) + if numOfTokens == maxTokens { + // if response should be created with maximum number of tokens - finish reason will be 'length' + finishReason = LengthFinishReason + } + } } diff --git a/pkg/common/utils_test.go b/pkg/common/utils_test.go index b05b0e31..d847df35 100644 --- a/pkg/common/utils_test.go +++ b/pkg/common/utils_test.go @@ -31,13 +31,13 @@ var _ = Describe("Utils", Ordered, func() { Context("GetRandomResponseText", func() { It("should return complete text", func() { - text, finishReason := GetRandomResponseText(nil) + 
text, finishReason := GetRandomResponseText(nil, false) Expect(IsValidText(text)).To(BeTrue()) Expect(finishReason).Should(Equal(StopFinishReason)) }) It("should return short text", func() { maxCompletionTokens := int64(2) - text, finishReason := GetRandomResponseText(&maxCompletionTokens) + text, finishReason := GetRandomResponseText(&maxCompletionTokens, false) tokensCnt := int64(len(Tokenize(text))) Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens)) if tokensCnt == maxCompletionTokens { @@ -50,7 +50,7 @@ var _ = Describe("Utils", Ordered, func() { It("should return long text", func() { // return required number of tokens although it is higher than ResponseLenMax maxCompletionTokens := int64(ResponseLenMax * 5) - text, finishReason := GetRandomResponseText(&maxCompletionTokens) + text, finishReason := GetRandomResponseText(&maxCompletionTokens, false) tokensCnt := int64(len(Tokenize(text))) Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens)) Expect(IsValidText(text)).To(BeTrue()) @@ -61,6 +61,23 @@ var _ = Describe("Utils", Ordered, func() { Expect(finishReason).To(Equal(StopFinishReason)) } }) + + DescribeTable("should return exact num of tokens", + func(maxCompletionTokens int) { + n := int64(maxCompletionTokens) + text, finishReason := GetRandomResponseText(&n, true) + nGenTokens := int64(len(Tokenize(text))) + Expect(nGenTokens).Should(Equal(n)) + Expect(finishReason).To(Equal(LengthFinishReason)) + }, + func(maxCompletionTokens int) string { + return fmt.Sprintf("maxCompletionTokens: %d", maxCompletionTokens) + }, + Entry("1", 1), + Entry("42", 42), + Entry("99", 99), + Entry("10000", 10000), + ) }) Context("GetResponseText", func() { diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 026a55c4..834df408 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -326,6 +326,10 @@ func (s *VllmSimulator) validateRequest(req 
openaiserverapi.CompletionRequest) ( return "Prefill does not support streaming", fasthttp.StatusBadRequest } + if req.GetIgnoreEOS() && req.GetMaxCompletionTokens() == nil { + return "Ignore_eos is true but max_completion_tokens (or max_tokens) is not set", fasthttp.StatusBadRequest + } + return "", fasthttp.StatusOK } diff --git a/pkg/openai-server-api/request.go b/pkg/openai-server-api/request.go index 675db162..e7d5fb3b 100644 --- a/pkg/openai-server-api/request.go +++ b/pkg/openai-server-api/request.go @@ -59,6 +59,8 @@ type CompletionRequest interface { GetToolChoice() string // GetMaxCompletionTokens returns the maximum completion tokens requested GetMaxCompletionTokens() *int64 + // GetIgnoreEOS returns true if the end-of-sequence tokens will be ignored + GetIgnoreEOS() bool // IsDoRemoteDecode() returns true if do_remote_decode field is true in the request, // when the field is true, the decode phase should be done on remote pod, // whereas prefill phase is done on local pod, thus this is a prefill request @@ -93,6 +95,8 @@ type baseCompletionRequest struct { RemotePort int `json:"remote_port"` // The number of tokens in the prompt that are in the local KV Cache cachedPromptTokens int + // IgnoreEOS is a boolean value, true when the model should ignore end-of-sequence tokens + IgnoreEOS bool `json:"ignore_eos"` } // StreamOptions defines streaming options for streaming requests @@ -131,6 +135,11 @@ func (b *baseCompletionRequest) GetNumberOfCachedPromptTokens() int { return b.cachedPromptTokens } +// GetIgnoreEOS returns the value of IgnoreEOS +func (b *baseCompletionRequest) GetIgnoreEOS() bool { + return b.IgnoreEOS +} + // SetNumberOfCachedPromptTokens sets the number of tokens in the prompt that are // in the local KV Cache func (b *baseCompletionRequest) SetNumberOfCachedPromptTokens(cachedPromptTokens int) { @@ -244,7 +253,7 @@ func (req ChatCompletionRequest) CreateResponseText(mode string) ([]string, stri if mode == common.ModeEcho { text, 
finishReason = common.GetResponseText(maxTokens, req.getLastUserMsg()) } else { - text, finishReason = common.GetRandomResponseText(maxTokens) + text, finishReason = common.GetRandomResponseText(maxTokens, req.GetIgnoreEOS()) } tokens := common.Tokenize(text) @@ -299,7 +308,7 @@ func (req TextCompletionRequest) CreateResponseText(mode string) ([]string, stri if mode == common.ModeEcho { text, finishReason = common.GetResponseText(maxTokens, req.Prompt) } else { - text, finishReason = common.GetRandomResponseText(maxTokens) + text, finishReason = common.GetRandomResponseText(maxTokens, req.GetIgnoreEOS()) } tokens := common.Tokenize(text)