Skip to content

Commit 84f3ad7

Browse files
committed
Return random from dataset if prompt hash does not hit
Signed-off-by: Qifan Deng <[email protected]>
1 parent 8c94372 commit 84f3ad7

File tree

1 file changed

+23
-12
lines changed

1 file changed

+23
-12
lines changed

pkg/dataset/custom_dataset.go

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import (
2929
"os"
3030
"os/signal"
3131
"path/filepath"
32+
"strconv"
3233
"syscall"
3334
"time"
3435

@@ -359,33 +360,43 @@ func (d *CustomDataset) GetTokens(req openaiserverapi.CompletionRequest, mode st
359360
return tokens, finishReason, err
360361
}
361362

362-
func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int) ([]string, error) {
363-
promptHash := d.GetPromptHash(req)
364-
promptHashHex := d.GetPromptHashHex(promptHash)
365-
rows, err := d.db.Query("SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + promptHashCol + "=X'" + promptHashHex + "';")
363+
func (d *CustomDataset) query(query string, nTokens int) ([][]string, error) {
364+
rows, err := d.db.Query(query)
366365
if err != nil {
367366
if !d.hasWarned {
368-
d.Logger.Error(err, "failed to query database. Ensure the prompt hash exists in the dataset. Will generate random tokens instead.")
367+
d.Logger.Error(err, "Failed to query database. Ensure dataset file is still valid. Will generate random tokens instead.")
369368
d.hasWarned = true
370369
}
371-
return GenPresetRandomTokens(nTokens), nil
370+
return [][]string{GenPresetRandomTokens(nTokens)}, nil
372371
}
373372
defer func() {
374373
if cerr := rows.Close(); cerr != nil {
375374
d.Logger.Error(cerr, "failed to close rows after query")
376375
}
377376
}()
377+
return unmarshalAllRecords(rows)
378+
}
378379

379-
tokensList, err := unmarshalAllRecords(rows)
380-
if err != nil {
381-
d.Logger.Error(err, "failed to unmarshal records from database")
382-
return GenPresetRandomTokens(nTokens), nil
380+
func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int) ([]string, error) {
381+
// query by prompt hash first
382+
promptHash := d.GetPromptHash(req)
383+
promptHashHex := d.GetPromptHashHex(promptHash)
384+
query := "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + promptHashCol + "=X'" + promptHashHex + "';"
385+
tokensList, err := d.query(query, nTokens)
386+
387+
if err != nil || len(tokensList) == 0 {
388+
// if the query by prompt hash errors or returns no rows, fall back to querying by number of tokens
389+
query = "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + nGenTokensCol + "=" + strconv.Itoa(nTokens) + ";"
390+
tokensList, err = d.query(query, nTokens)
383391
}
384392

385-
if len(tokensList) == 0 {
393+
if err != nil || len(tokensList) == 0 {
394+
// if both queries fail or return no results, generate random tokens
386395
return GenPresetRandomTokens(nTokens), nil
387396
}
388-
d.hasWarned = false
397+
if d.hasWarned {
398+
d.hasWarned = false
399+
}
389400
randIndex := common.RandomInt(0, len(tokensList)-1)
390401
return tokensList[randIndex], nil
391402
}

0 commit comments

Comments
 (0)