15 changes: 10 additions & 5 deletions pkg/common/utils.go
@@ -165,7 +165,7 @@ func GetRandomText(numOfTokens int) string {
 // if maxCompletionTokens is nil
 // - the response text's length is randomly chosen from the range [1, responseLenMax] according to additional parameters
 // - finish reason is stop
-func GetRandomResponseText(maxCompletionTokens *int64) (string, string) {
+func GetRandomResponseText(maxCompletionTokens *int64, ignoreEOS bool) (string, string) {
 	numOfTokens := 0
 	finishReason := StopFinishReason
 
@@ -174,11 +174,16 @@ func GetRandomResponseText(maxCompletionTokens *int64) (string, string) {
 		numOfTokens = GetRandomResponseLen()
 	} else {
 		maxTokens := int(*maxCompletionTokens)
-		// max tokens is defined - generate the real length of the response based on it
-		numOfTokens = getResponseLengthByHistogram(maxTokens)
-		if numOfTokens == maxTokens {
-			// if the response should be created with the maximum number of tokens, the finish reason will be 'length'
-			finishReason = LengthFinishReason
-		}
+		if ignoreEOS {
+			numOfTokens = maxTokens
+			finishReason = LengthFinishReason
+		} else {
+			// max tokens is defined - generate the real length of the response based on it
+			numOfTokens = getResponseLengthByHistogram(maxTokens)
+			if numOfTokens == maxTokens {
+				// if the response should be created with the maximum number of tokens, the finish reason will be 'length'
+				finishReason = LengthFinishReason
+			}
+		}
 	}
 
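The net effect in pkg/common/utils.go: when ignore_eos is set, the histogram-based length sampling is skipped entirely, the response contains exactly max tokens, and the finish reason is always 'length'. A minimal caller sketch of both paths (the module path in the import is an assumption, not shown in this diff):

package main

import (
	"fmt"

	"github.com/llm-d/llm-d-inference-sim/pkg/common" // assumed module path
)

func main() {
	maxTokens := int64(8)

	// ignoreEOS=false: the length is sampled via the histogram, so it may be
	// any value in [1, maxTokens] and the finish reason may be "stop".
	text, finishReason := common.GetRandomResponseText(&maxTokens, false)
	fmt.Println(len(common.Tokenize(text)), finishReason)

	// ignoreEOS=true: exactly maxTokens tokens, finish reason always "length".
	text, finishReason = common.GetRandomResponseText(&maxTokens, true)
	fmt.Println(len(common.Tokenize(text)), finishReason) // prints: 8 length
}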
23 changes: 20 additions & 3 deletions pkg/common/utils_test.go
@@ -31,13 +31,13 @@ var _ = Describe("Utils", Ordered, func() {
 
 	Context("GetRandomResponseText", func() {
 		It("should return complete text", func() {
-			text, finishReason := GetRandomResponseText(nil)
+			text, finishReason := GetRandomResponseText(nil, false)
 			Expect(IsValidText(text)).To(BeTrue())
 			Expect(finishReason).Should(Equal(StopFinishReason))
 		})
 		It("should return short text", func() {
 			maxCompletionTokens := int64(2)
-			text, finishReason := GetRandomResponseText(&maxCompletionTokens)
+			text, finishReason := GetRandomResponseText(&maxCompletionTokens, false)
 			tokensCnt := int64(len(Tokenize(text)))
 			Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens))
 			if tokensCnt == maxCompletionTokens {
@@ -50,7 +50,7 @@ var _ = Describe("Utils", Ordered, func() {
 		It("should return long text", func() {
 			// the required number of tokens is returned even when it is higher than ResponseLenMax
 			maxCompletionTokens := int64(ResponseLenMax * 5)
-			text, finishReason := GetRandomResponseText(&maxCompletionTokens)
+			text, finishReason := GetRandomResponseText(&maxCompletionTokens, false)
 			tokensCnt := int64(len(Tokenize(text)))
 			Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens))
 			Expect(IsValidText(text)).To(BeTrue())
@@ -61,6 +61,23 @@ var _ = Describe("Utils", Ordered, func() {
 				Expect(finishReason).To(Equal(StopFinishReason))
 			}
 		})
+
+		DescribeTable("should return the exact number of tokens",
+			func(maxCompletionTokens int) {
+				n := int64(maxCompletionTokens)
+				text, finishReason := GetRandomResponseText(&n, true)
+				nGenTokens := int64(len(Tokenize(text)))
+				Expect(nGenTokens).Should(Equal(n))
+				Expect(finishReason).To(Equal(LengthFinishReason))
+			},
+			func(maxCompletionTokens int) string {
+				return fmt.Sprintf("maxCompletionTokens: %d", maxCompletionTokens)
+			},
+			Entry("1", 1),
+			Entry("42", 42),
+			Entry("99", 99),
+			Entry("10000", 10000),
+		)
 	})
 
 	Context("GetResponseText", func() {
4 changes: 4 additions & 0 deletions pkg/llm-d-inference-sim/simulator.go
@@ -326,6 +326,10 @@ func (s *VllmSimulator) validateRequest(req openaiserverapi.CompletionRequest) (
 		return "Prefill does not support streaming", fasthttp.StatusBadRequest
 	}
 
+	if req.GetIgnoreEOS() && req.GetMaxCompletionTokens() == nil {
+		return "Ignore_eos is true but max_completion_tokens (or max_tokens) is not set", fasthttp.StatusBadRequest
+	}
+
 	return "", fasthttp.StatusOK
 }

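This validation mirrors the generator change: since ignore_eos forces generation to run until the token limit, a limit must exist. A hedged HTTP-level sketch of a request the simulator would now reject (port, endpoint path, and model name are illustrative assumptions):

package main

import (
	"fmt"
	"net/http"
	"strings"
)

func main() {
	// ignore_eos without max_tokens (or max_completion_tokens) is now a 400
	body := `{"model": "my-model", "prompt": "hi", "ignore_eos": true}`
	resp, err := http.Post("http://localhost:8000/v1/completions",
		"application/json", strings.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.StatusCode) // expected: 400 (Bad Request)
}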
20 changes: 18 additions & 2 deletions pkg/openai-server-api/request.go
@@ -59,6 +59,8 @@ type CompletionRequest interface {
 	GetToolChoice() string
 	// GetMaxCompletionTokens returns the maximum completion tokens requested
 	GetMaxCompletionTokens() *int64
+	// GetIgnoreEOS returns true if end-of-sequence tokens should be ignored during generation
+	GetIgnoreEOS() bool
 	// IsDoRemoteDecode() returns true if the do_remote_decode field is true in the request;
 	// when the field is true, the decode phase should be done on a remote pod,
 	// whereas the prefill phase is done on the local pod, thus this is a prefill request
@@ -164,6 +166,9 @@ type ChatCompletionRequest struct {
 	// tokens and reasoning tokens.
 	MaxCompletionTokens *int64 `json:"max_completion_tokens"`
 
+	// IgnoreEOS, when true, tells the model to ignore end-of-sequence tokens
+	IgnoreEOS bool `json:"ignore_eos"`
+
 	// Tools is a list of tools the model may call.
 	Tools []Tool `json:"tools,omitempty"`
 
@@ -219,6 +224,10 @@ func (c *ChatCompletionRequest) GetMaxCompletionTokens() *int64 {
 	return c.MaxTokens
 }
 
+func (c *ChatCompletionRequest) GetIgnoreEOS() bool {
+	return c.IgnoreEOS
+}
+
 // getLastUserMsg returns the last user-role message from this request's messages;
 // if none exists, it returns an empty string
 func (req *ChatCompletionRequest) getLastUserMsg() string {
@@ -244,7 +253,7 @@ func (req ChatCompletionRequest) CreateResponseText(mode string) ([]string, stri
 	if mode == common.ModeEcho {
 		text, finishReason = common.GetResponseText(maxTokens, req.getLastUserMsg())
 	} else {
-		text, finishReason = common.GetRandomResponseText(maxTokens)
+		text, finishReason = common.GetRandomResponseText(maxTokens, req.GetIgnoreEOS())
 	}
 
 	tokens := common.Tokenize(text)
@@ -264,6 +273,9 @@ type TextCompletionRequest struct {
 	// The token count of your prompt plus `max_tokens` cannot exceed the model's
 	// context length.
 	MaxTokens *int64 `json:"max_tokens"`
+
+	// IgnoreEOS, when true, tells the model to ignore end-of-sequence tokens
+	IgnoreEOS bool `json:"ignore_eos"`
 }
 
 func (t *TextCompletionRequest) GetPrompt() string {
@@ -286,6 +298,10 @@ func (c *TextCompletionRequest) GetMaxCompletionTokens() *int64 {
 	return c.MaxTokens
 }
 
+func (c *TextCompletionRequest) GetIgnoreEOS() bool {
+	return c.IgnoreEOS
+}
+
 // CreateResponseText creates and returns the response payload based on this request,
 // i.e., an array of generated tokens, the finish reason, and the number of created
 // tokens
@@ -299,7 +315,7 @@ func (req TextCompletionRequest) CreateResponseText(mode string) ([]string, stri
 	if mode == common.ModeEcho {
 		text, finishReason = common.GetResponseText(maxTokens, req.Prompt)
 	} else {
-		text, finishReason = common.GetRandomResponseText(maxTokens)
+		text, finishReason = common.GetRandomResponseText(maxTokens, req.GetIgnoreEOS())
 	}
 
 	tokens := common.Tokenize(text)
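Because both concrete request types now implement GetIgnoreEOS, callers such as validateRequest and the two CreateResponseText methods read the flag through the shared CompletionRequest interface without type switches. A compile-time sketch of that contract (the aliased import path is an assumption, and the real interface has more methods than this diff shows):

package main

import (
	openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" // assumed import path
)

// Compile-time assertions: both request types satisfy the extended interface.
var (
	_ openaiserverapi.CompletionRequest = (*openaiserverapi.ChatCompletionRequest)(nil)
	_ openaiserverapi.CompletionRequest = (*openaiserverapi.TextCompletionRequest)(nil)
)

func main() {}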