Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions pkg/common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,9 @@ func GetRandomText(numOfTokens int) string {
// if maxCompletionTokens is nil
// - the response text's length is randomly chosen from the range [1, responseLenMax] according to additional parameters
// - finish reason is stop
func GetRandomResponseText(maxCompletionTokens *int64) (string, string) {
// if ignore_eos is true - the response will be generated with exactly maxCompletionTokens tokens
// - request was validated so that when ignore_eos is true, maxCompletionTokens must be defined
func GetRandomResponseText(maxCompletionTokens *int64, ignore_eos bool) (string, string) {
numOfTokens := 0
finishReason := StopFinishReason

Expand All @@ -174,11 +176,16 @@ func GetRandomResponseText(maxCompletionTokens *int64) (string, string) {
numOfTokens = GetRandomResponseLen()
} else {
maxTokens := int(*maxCompletionTokens)
// max tokens is defined - generate real length of the response based on it
numOfTokens = getResponseLengthByHistogram(maxTokens)
if numOfTokens == maxTokens {
// if response should be created with the maximum number of tokens - finish reason will be 'length'
if ignore_eos {
numOfTokens = maxTokens
finishReason = LengthFinishReason
} else {
// max tokens is defined - generate real length of the response based on it
numOfTokens = getResponseLengthByHistogram(maxTokens)
if numOfTokens == maxTokens {
// if response should be created with the maximum number of tokens - finish reason will be 'length'
finishReason = LengthFinishReason
}
}
}

Expand Down
23 changes: 20 additions & 3 deletions pkg/common/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ var _ = Describe("Utils", Ordered, func() {

Context("GetRandomResponseText", func() {
It("should return complete text", func() {
text, finishReason := GetRandomResponseText(nil)
text, finishReason := GetRandomResponseText(nil, false)
Expect(IsValidText(text)).To(BeTrue())
Expect(finishReason).Should(Equal(StopFinishReason))
})
It("should return short text", func() {
maxCompletionTokens := int64(2)
text, finishReason := GetRandomResponseText(&maxCompletionTokens)
text, finishReason := GetRandomResponseText(&maxCompletionTokens, false)
tokensCnt := int64(len(Tokenize(text)))
Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens))
if tokensCnt == maxCompletionTokens {
Expand All @@ -50,7 +50,7 @@ var _ = Describe("Utils", Ordered, func() {
It("should return long text", func() {
// return required number of tokens although it is higher than ResponseLenMax
maxCompletionTokens := int64(ResponseLenMax * 5)
text, finishReason := GetRandomResponseText(&maxCompletionTokens)
text, finishReason := GetRandomResponseText(&maxCompletionTokens, false)
tokensCnt := int64(len(Tokenize(text)))
Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens))
Expect(IsValidText(text)).To(BeTrue())
Expand All @@ -61,6 +61,23 @@ var _ = Describe("Utils", Ordered, func() {
Expect(finishReason).To(Equal(StopFinishReason))
}
})

DescribeTable("should return exact num of tokens",
func(maxCompletionTokens int) {
n := int64(maxCompletionTokens)
text, finishReason := GetRandomResponseText(&n, true)
nGenTokens := int64(len(Tokenize(text)))
Expect(nGenTokens).Should(Equal(n))
Expect(finishReason).To(Equal(LengthFinishReason))
},
func(maxCompletionTokens int) string {
return fmt.Sprintf("maxCompletionTokens: %d", maxCompletionTokens)
},
Entry("1", 1),
Entry("42", 42),
Entry("99", 99),
Entry("10000", 10000),
)
})

Context("GetResponseText", func() {
Expand Down
4 changes: 4 additions & 0 deletions pkg/llm-d-inference-sim/simulator.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,10 @@ func (s *VllmSimulator) validateRequest(req openaiserverapi.CompletionRequest) (
return "Prefill does not support streaming", fasthttp.StatusBadRequest
}

if req.GetIgnoreEOS() && req.GetMaxCompletionTokens() == nil {
return "Ignore_eos is true but max_completion_tokens (or max_tokens) is not set", fasthttp.StatusBadRequest
}

return "", fasthttp.StatusOK
}

Expand Down
13 changes: 11 additions & 2 deletions pkg/openai-server-api/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ type CompletionRequest interface {
GetToolChoice() string
// GetMaxCompletionTokens returns the maximum completion tokens requested
GetMaxCompletionTokens() *int64
// GetIgnoreEOS returns true if the end-of-sequence tokens will be ignored
GetIgnoreEOS() bool
// IsDoRemoteDecode() returns true if do_remote_decode field is true in the request,
// when the field is true, the decode phase should be done on remote pod,
// whereas prefill phase is done on local pod, thus this is a prefill request
Expand Down Expand Up @@ -93,6 +95,8 @@ type baseCompletionRequest struct {
RemotePort int `json:"remote_port"`
// The number of tokens in the prompt that are in the local KV Cache
cachedPromptTokens int
// IgnoreEOS is a boolean value, true when the model should ignore end-of-sequence tokens
IgnoreEOS bool `json:"ignore_eos"`
}

// StreamOptions defines streaming options for streaming requests
Expand Down Expand Up @@ -131,6 +135,11 @@ func (b *baseCompletionRequest) GetNumberOfCachedPromptTokens() int {
return b.cachedPromptTokens
}

// GetIgnoreEOS reports whether the request asked to ignore end-of-sequence
// tokens, in which case the response is generated with exactly
// max_completion_tokens tokens (validated elsewhere to be set together).
func (b *baseCompletionRequest) GetIgnoreEOS() bool {
	return b.IgnoreEOS
}

// SetNumberOfCachedPromptTokens sets the number of tokens in the prompt that are
// in the local KV Cache
func (b *baseCompletionRequest) SetNumberOfCachedPromptTokens(cachedPromptTokens int) {
Expand Down Expand Up @@ -244,7 +253,7 @@ func (req ChatCompletionRequest) CreateResponseText(mode string) ([]string, stri
if mode == common.ModeEcho {
text, finishReason = common.GetResponseText(maxTokens, req.getLastUserMsg())
} else {
text, finishReason = common.GetRandomResponseText(maxTokens)
text, finishReason = common.GetRandomResponseText(maxTokens, req.GetIgnoreEOS())
}

tokens := common.Tokenize(text)
Expand Down Expand Up @@ -299,7 +308,7 @@ func (req TextCompletionRequest) CreateResponseText(mode string) ([]string, stri
if mode == common.ModeEcho {
text, finishReason = common.GetResponseText(maxTokens, req.Prompt)
} else {
text, finishReason = common.GetRandomResponseText(maxTokens)
text, finishReason = common.GetRandomResponseText(maxTokens, req.GetIgnoreEOS())
}

tokens := common.Tokenize(text)
Expand Down
Loading