@@ -29,6 +29,7 @@ import (
2929 "os"
3030 "os/signal"
3131 "path/filepath"
32+ "strconv"
3233 "syscall"
3334 "time"
3435
@@ -359,33 +360,43 @@ func (d *CustomDataset) GetTokens(req openaiserverapi.CompletionRequest, mode st
359360 return tokens , finishReason , err
360361}
361362
362- func (d * CustomDataset ) GenerateTokens (req openaiserverapi.CompletionRequest , nTokens int ) ([]string , error ) {
363- promptHash := d .GetPromptHash (req )
364- promptHashHex := d .GetPromptHashHex (promptHash )
365- rows , err := d .db .Query ("SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + promptHashCol + "=X'" + promptHashHex + "';" )
363+ func (d * CustomDataset ) query (query string , nTokens int ) ([][]string , error ) {
364+ rows , err := d .db .Query (query )
366365 if err != nil {
367366 if ! d .hasWarned {
368- d .Logger .Error (err , "failed to query database. Ensure the prompt hash exists in the dataset . Will generate random tokens instead." )
367+ d .Logger .Error (err , "Failed to query database. Ensure dataset file is still valid . Will generate random tokens instead." )
369368 d .hasWarned = true
370369 }
371- return GenPresetRandomTokens (nTokens ), nil
370+ return [][] string { GenPresetRandomTokens (nTokens )} , nil
372371 }
373372 defer func () {
374373 if cerr := rows .Close (); cerr != nil {
375374 d .Logger .Error (cerr , "failed to close rows after query" )
376375 }
377376 }()
377+ return unmarshalAllRecords (rows )
378+ }
378379
379- tokensList , err := unmarshalAllRecords (rows )
380- if err != nil {
381- d .Logger .Error (err , "failed to unmarshal records from database" )
382- return GenPresetRandomTokens (nTokens ), nil
380+ func (d * CustomDataset ) GenerateTokens (req openaiserverapi.CompletionRequest , nTokens int ) ([]string , error ) {
381+ // query by prompt hash first
382+ promptHash := d .GetPromptHash (req )
383+ promptHashHex := d .GetPromptHashHex (promptHash )
384+ query := "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + promptHashCol + "=X'" + promptHashHex + "';"
385+ tokensList , err := d .query (query , nTokens )
386+
387+ if err != nil || len (tokensList ) == 0 {
388+ // if query by prompt hash fails, fallback to query by number of tokens
389+ query = "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + nGenTokensCol + "=" + strconv .Itoa (nTokens ) + ";"
390+ tokensList , err = d .query (query , nTokens )
383391 }
384392
385- if len (tokensList ) == 0 {
393+ if err != nil || len (tokensList ) == 0 {
394+ // if both queries fail or return no results, generate random tokens
386395 return GenPresetRandomTokens (nTokens ), nil
387396 }
388- d .hasWarned = false
397+ if d .hasWarned {
398+ d .hasWarned = false
399+ }
389400 randIndex := common .RandomInt (0 , len (tokensList )- 1 )
390401 return tokensList [randIndex ], nil
391402}
0 commit comments