@@ -356,7 +356,7 @@ func (d *CustomDataset) GetTokens(req openaiserverapi.CompletionRequest, mode st
356356 return d .echo (req )
357357 }
358358 nTokensToGen , finishReason := howManyTokensToGen (d .extractMaxTokens (req ), req .GetIgnoreEOS ())
359- tokens , err := d .GenerateTokens (req , nTokensToGen )
359+ tokens , err := d .GenerateTokens (req , nTokensToGen , finishReason )
360360 return tokens , finishReason , err
361361}
362362
@@ -377,17 +377,36 @@ func (d *CustomDataset) query(query string, nTokens int) ([][]string, error) {
377377 return unmarshalAllRecords (rows )
378378}
379379
380- func (d * CustomDataset ) GenerateTokens (req openaiserverapi.CompletionRequest , nTokens int ) ([]string , error ) {
380+ func (d * CustomDataset ) GenerateTokens (req openaiserverapi.CompletionRequest , nTokens int , finishReason string ) ([]string , error ) {
381381 // query by prompt hash first
382382 promptHash := d .GetPromptHash (req )
383383 promptHashHex := d .GetPromptHashHex (promptHash )
384384 query := "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + promptHashCol + "=X'" + promptHashHex + "';"
385385 tokensList , err := d .query (query , nTokens )
386386
387- if err != nil || len (tokensList ) == 0 {
388- // if query by prompt hash fails, fallback to query by number of tokens
389- query = "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + nGenTokensCol + "=" + strconv .Itoa (nTokens ) + ";"
390- tokensList , err = d .query (query , nTokens )
387+ // filter out results according to finish reason
388+ var filteredTokensList [][]string
389+ if finishReason != LengthFinishReason && finishReason != StopFinishReason {
390+ d .Logger .Error (errors .New ("unknown finish reason" ), "Unexpected finish reason" , "reason" , finishReason )
391+ }
392+ for _ , tokens := range tokensList {
393+ if finishReason == StopFinishReason && len (tokens ) <= nTokens {
394+ filteredTokensList = append (filteredTokensList , tokens )
395+ } else if finishReason == LengthFinishReason && len (tokens ) == nTokens {
396+ filteredTokensList = append (filteredTokensList , tokens )
397+ }
398+ }
399+ tokensList = filteredTokensList
400+
401+ if err != nil || len (filteredTokensList ) == 0 {
402+ switch finishReason {
403+ case LengthFinishReason :
404+ query = "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + nGenTokensCol + "=" + strconv .Itoa (nTokens ) + ";"
405+ tokensList , err = d .query (query , nTokens )
406+ case StopFinishReason :
407+ query = "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + nGenTokensCol + "<=" + strconv .Itoa (nTokens ) + ";"
408+ tokensList , err = d .query (query , nTokens )
409+ }
391410 }
392411
393412 if err != nil || len (tokensList ) == 0 {
0 commit comments