@@ -29,6 +29,7 @@ import (
2929	"os" 
3030	"os/signal" 
3131	"path/filepath" 
32+ 	"strconv" 
3233	"syscall" 
3334	"time" 
3435
@@ -359,33 +360,43 @@ func (d *CustomDataset) GetTokens(req openaiserverapi.CompletionRequest, mode st
359360	return  tokens , finishReason , err 
360361}
361362
362- func  (d  * CustomDataset ) GenerateTokens (req  openaiserverapi.CompletionRequest , nTokens  int ) ([]string , error ) {
363- 	promptHash  :=  d .GetPromptHash (req )
364- 	promptHashHex  :=  d .GetPromptHashHex (promptHash )
365- 	rows , err  :=  d .db .Query ("SELECT "  +  genTokensCol  +  " FROM "  +  tableName  +  " WHERE "  +  promptHashCol  +  "=X'"  +  promptHashHex  +  "';" )
363+ func  (d  * CustomDataset ) query (query  string , nTokens  int ) ([][]string , error ) {
364+ 	rows , err  :=  d .db .Query (query )
366365	if  err  !=  nil  {
367366		if  ! d .hasWarned  {
368- 			d .Logger .Error (err , "failed  to query database. Ensure the prompt hash exists in the dataset . Will generate random tokens instead." )
367+ 			d .Logger .Error (err , "Failed  to query database. Ensure dataset file is still valid . Will generate random tokens instead." )
369368			d .hasWarned  =  true 
370369		}
371- 		return  GenPresetRandomTokens (nTokens ), nil 
370+ 		return  [][] string { GenPresetRandomTokens (nTokens )} , nil 
372371	}
373372	defer  func () {
374373		if  cerr  :=  rows .Close (); cerr  !=  nil  {
375374			d .Logger .Error (cerr , "failed to close rows after query" )
376375		}
377376	}()
377+ 	return  unmarshalAllRecords (rows )
378+ }
378379
379- 	tokensList , err  :=  unmarshalAllRecords (rows )
380- 	if  err  !=  nil  {
381- 		d .Logger .Error (err , "failed to unmarshal records from database" )
382- 		return  GenPresetRandomTokens (nTokens ), nil 
380+ func  (d  * CustomDataset ) GenerateTokens (req  openaiserverapi.CompletionRequest , nTokens  int ) ([]string , error ) {
381+ 	// query by prompt hash first 
382+ 	promptHash  :=  d .GetPromptHash (req )
383+ 	promptHashHex  :=  d .GetPromptHashHex (promptHash )
384+ 	query  :=  "SELECT "  +  genTokensCol  +  " FROM "  +  tableName  +  " WHERE "  +  promptHashCol  +  "=X'"  +  promptHashHex  +  "';" 
385+ 	tokensList , err  :=  d .query (query , nTokens )
386+ 
387+ 	if  err  !=  nil  ||  len (tokensList ) ==  0  {
388+ 		// if query by prompt hash fails, fallback to query by number of tokens 
389+ 		query  =  "SELECT "  +  genTokensCol  +  " FROM "  +  tableName  +  " WHERE "  +  nGenTokensCol  +  "="  +  strconv .Itoa (nTokens ) +  ";" 
390+ 		tokensList , err  =  d .query (query , nTokens )
383391	}
384392
385- 	if  len (tokensList ) ==  0  {
393+ 	if  err  !=  nil  ||  len (tokensList ) ==  0  {
394+ 		// if both queries fail or return no results, generate random tokens 
386395		return  GenPresetRandomTokens (nTokens ), nil 
387396	}
388- 	d .hasWarned  =  false 
397+ 	if  d .hasWarned  {
398+ 		d .hasWarned  =  false 
399+ 	}
389400	randIndex  :=  common .RandomInt (0 , len (tokensList )- 1 )
390401	return  tokensList [randIndex ], nil 
391402}
0 commit comments