Skip to content

Commit 46e5d1e

Browse files
committed
Respect maxTokens
Signed-off-by: Qifan Deng <[email protected]>
1 parent 26a0b8f commit 46e5d1e

File tree

2 files changed: 39 additions and 6 deletions

pkg/dataset/custom_dataset.go

Lines changed: 25 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -356,7 +356,7 @@ func (d *CustomDataset) GetTokens(req openaiserverapi.CompletionRequest, mode st
356356
return d.echo(req)
357357
}
358358
nTokensToGen, finishReason := howManyTokensToGen(d.extractMaxTokens(req), req.GetIgnoreEOS())
359-
tokens, err := d.GenerateTokens(req, nTokensToGen)
359+
tokens, err := d.GenerateTokens(req, nTokensToGen, finishReason)
360360
return tokens, finishReason, err
361361
}
362362

@@ -377,17 +377,36 @@ func (d *CustomDataset) query(query string, nTokens int) ([][]string, error) {
377377
return unmarshalAllRecords(rows)
378378
}
379379

380-
func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int) ([]string, error) {
380+
func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int, finishReason string) ([]string, error) {
381381
// query by prompt hash first
382382
promptHash := d.GetPromptHash(req)
383383
promptHashHex := d.GetPromptHashHex(promptHash)
384384
query := "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + promptHashCol + "=X'" + promptHashHex + "';"
385385
tokensList, err := d.query(query, nTokens)
386386

387-
if err != nil || len(tokensList) == 0 {
388-
// if query by prompt hash fails, fallback to query by number of tokens
389-
query = "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + nGenTokensCol + "=" + strconv.Itoa(nTokens) + ";"
390-
tokensList, err = d.query(query, nTokens)
387+
// filter out results according to finish reason
388+
var filteredTokensList [][]string
389+
if finishReason != LengthFinishReason && finishReason != StopFinishReason {
390+
d.Logger.Error(errors.New("unknown finish reason"), "Unexpected finish reason", "reason", finishReason)
391+
}
392+
for _, tokens := range tokensList {
393+
if finishReason == StopFinishReason && len(tokens) <= nTokens {
394+
filteredTokensList = append(filteredTokensList, tokens)
395+
} else if finishReason == LengthFinishReason && len(tokens) == nTokens {
396+
filteredTokensList = append(filteredTokensList, tokens)
397+
}
398+
}
399+
tokensList = filteredTokensList
400+
401+
if err != nil || len(filteredTokensList) == 0 {
402+
switch finishReason {
403+
case LengthFinishReason:
404+
query = "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + nGenTokensCol + "=" + strconv.Itoa(nTokens) + ";"
405+
tokensList, err = d.query(query, nTokens)
406+
case StopFinishReason:
407+
query = "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + nGenTokensCol + "<=" + strconv.Itoa(nTokens) + ";"
408+
tokensList, err = d.query(query, nTokens)
409+
}
391410
}
392411

393412
if err != nil || len(tokensList) == 0 {

pkg/dataset/custom_dataset_test.go

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -184,4 +184,18 @@ var _ = Describe("CustomDataset", Ordered, func() {
184184
Expect(finishReason).To(Equal(StopFinishReason))
185185
Expect(tokens).To(Equal([]string{"Hello", " llm-d ", "world", "!"}))
186186
})
187+
188+
It("should return at most 2 tokens for existing prompt", func() {
189+
err := dataset.Init(validDBPath, "")
190+
Expect(err).NotTo(HaveOccurred())
191+
n := int64(2)
192+
req := &openaiserverapi.TextCompletionRequest{
193+
Prompt: testPrompt,
194+
MaxTokens: &n,
195+
}
196+
tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom)
197+
Expect(err).NotTo(HaveOccurred())
198+
Expect(finishReason).To(Equal(LengthFinishReason))
199+
Expect(len(tokens)).To(BeNumerically("<=", 2))
200+
})
187201
})

0 commit comments

Comments
 (0)