@@ -356,7 +356,7 @@ func (d *CustomDataset) GetTokens(req openaiserverapi.CompletionRequest, mode st
356356		return  d .echo (req )
357357	}
358358	nTokensToGen , finishReason  :=  howManyTokensToGen (d .extractMaxTokens (req ), req .GetIgnoreEOS ())
359- 	tokens , err  :=  d .GenerateTokens (req , nTokensToGen )
359+ 	tokens , err  :=  d .GenerateTokens (req , nTokensToGen ,  finishReason )
360360	return  tokens , finishReason , err 
361361}
362362
@@ -377,17 +377,36 @@ func (d *CustomDataset) query(query string, nTokens int) ([][]string, error) {
377377	return  unmarshalAllRecords (rows )
378378}
379379
380- func  (d  * CustomDataset ) GenerateTokens (req  openaiserverapi.CompletionRequest , nTokens  int ) ([]string , error ) {
380+ func  (d  * CustomDataset ) GenerateTokens (req  openaiserverapi.CompletionRequest , nTokens  int ,  finishReason   string ) ([]string , error ) {
381381	// query by prompt hash first 
382382	promptHash  :=  d .GetPromptHash (req )
383383	promptHashHex  :=  d .GetPromptHashHex (promptHash )
384384	query  :=  "SELECT "  +  genTokensCol  +  " FROM "  +  tableName  +  " WHERE "  +  promptHashCol  +  "=X'"  +  promptHashHex  +  "';" 
385385	tokensList , err  :=  d .query (query , nTokens )
386386
387- 	if  err  !=  nil  ||  len (tokensList ) ==  0  {
388- 		// if query by prompt hash fails, fallback to query by number of tokens 
389- 		query  =  "SELECT "  +  genTokensCol  +  " FROM "  +  tableName  +  " WHERE "  +  nGenTokensCol  +  "="  +  strconv .Itoa (nTokens ) +  ";" 
390- 		tokensList , err  =  d .query (query , nTokens )
387+ 	// filter out results according to finish reason 
388+ 	var  filteredTokensList  [][]string 
389+ 	if  finishReason  !=  LengthFinishReason  &&  finishReason  !=  StopFinishReason  {
390+ 		d .Logger .Error (errors .New ("unknown finish reason" ), "Unexpected finish reason" , "reason" , finishReason )
391+ 	}
392+ 	for  _ , tokens  :=  range  tokensList  {
393+ 		if  finishReason  ==  StopFinishReason  &&  len (tokens ) <=  nTokens  {
394+ 			filteredTokensList  =  append (filteredTokensList , tokens )
395+ 		} else  if  finishReason  ==  LengthFinishReason  &&  len (tokens ) ==  nTokens  {
396+ 			filteredTokensList  =  append (filteredTokensList , tokens )
397+ 		}
398+ 	}
399+ 	tokensList  =  filteredTokensList 
400+ 
401+ 	if  err  !=  nil  ||  len (filteredTokensList ) ==  0  {
402+ 		switch  finishReason  {
403+ 		case  LengthFinishReason :
404+ 			query  =  "SELECT "  +  genTokensCol  +  " FROM "  +  tableName  +  " WHERE "  +  nGenTokensCol  +  "="  +  strconv .Itoa (nTokens ) +  ";" 
405+ 			tokensList , err  =  d .query (query , nTokens )
406+ 		case  StopFinishReason :
407+ 			query  =  "SELECT "  +  genTokensCol  +  " FROM "  +  tableName  +  " WHERE "  +  nGenTokensCol  +  "<="  +  strconv .Itoa (nTokens ) +  ";" 
408+ 			tokensList , err  =  d .query (query , nTokens )
409+ 		}
391410	}
392411
393412	if  err  !=  nil  ||  len (tokensList ) ==  0  {
0 commit comments