Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 20 additions & 19 deletions pkg/common/tools_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ func CreateToolCalls(
tools []openaiserverapi.Tool,
toolChoice openaiserverapi.ToolChoice,
config *Configuration,
random *Random,
) ([]openaiserverapi.ToolCall, int, error) {
generateCalls := func(availableTools []openaiserverapi.Tool, minCalls int) ([]openaiserverapi.ToolCall, int, error) {
if len(availableTools) == 0 {
Expand All @@ -102,7 +103,7 @@ func CreateToolCalls(
numberOfCalls := minCalls
if len(availableTools) > minCalls {
// Randomly decide how many tools to call, between minCalls and the total available.
numberOfCalls = RandomInt(minCalls, len(availableTools))
numberOfCalls = random.RandomInt(minCalls, len(availableTools))
}

if numberOfCalls == 0 {
Expand All @@ -114,11 +115,11 @@ func CreateToolCalls(
// Randomly choose which tool to call. We may call the same tool more than once.
index := 0
if len(availableTools) > 1 {
index = RandomInt(0, len(availableTools)-1)
index = random.RandomInt(0, len(availableTools)-1)
}
chosenTool := availableTools[index]

args, err := generateToolArguments(chosenTool, config)
args, err := generateToolArguments(chosenTool, config, random)
if err != nil {
return nil, 0, err
}
Expand All @@ -133,7 +134,7 @@ func CreateToolCalls(
TokenizedArguments: Tokenize(string(argsJson)),
Name: &chosenTool.Function.Name,
},
ID: "chatcmpl-tool-" + RandomNumericString(10),
ID: "chatcmpl-tool-" + random.RandomNumericString(10),
Type: "function",
Index: i,
}
Expand Down Expand Up @@ -188,18 +189,18 @@ func getRequiredAsMap(property map[string]any) map[string]struct{} {
return required
}

func generateToolArguments(tool openaiserverapi.Tool, config *Configuration) (map[string]any, error) {
func generateToolArguments(tool openaiserverapi.Tool, config *Configuration, random *Random) (map[string]any, error) {
arguments := make(map[string]any)
properties, _ := tool.Function.Parameters["properties"].(map[string]any)

required := getRequiredAsMap(tool.Function.Parameters)

for param, property := range properties {
_, paramIsRequired := required[param]
if !paramIsRequired && !RandomBool(config.ToolCallNotRequiredParamProbability) {
if !paramIsRequired && !random.RandomBool(config.ToolCallNotRequiredParamProbability) {
continue
}
arg, err := createArgument(property, config)
arg, err := createArgument(property, config, random)
if err != nil {
return nil, err
}
Expand All @@ -209,7 +210,7 @@ func generateToolArguments(tool openaiserverapi.Tool, config *Configuration) (ma
return arguments, nil
}

func createArgument(property any, config *Configuration) (any, error) {
func createArgument(property any, config *Configuration, random *Random) (any, error) {
propertyMap, _ := property.(map[string]any)
paramType := propertyMap["type"]

Expand All @@ -218,20 +219,20 @@ func createArgument(property any, config *Configuration) (any, error) {
if ok {
enumArray, ok := enum.([]any)
if ok && len(enumArray) > 0 {
index := RandomInt(0, len(enumArray)-1)
index := random.RandomInt(0, len(enumArray)-1)
return enumArray[index], nil
}
}

switch paramType {
case "string":
return getStringArgument(), nil
return getStringArgument(random), nil
case "integer":
return RandomInt(config.MinToolCallIntegerParam, config.MaxToolCallIntegerParam), nil
return random.RandomInt(config.MinToolCallIntegerParam, config.MaxToolCallIntegerParam), nil
case "number":
return RandomFloat(config.MinToolCallNumberParam, config.MaxToolCallNumberParam), nil
return random.RandomFloat(config.MinToolCallNumberParam, config.MaxToolCallNumberParam), nil
case "boolean":
return FlipCoin(), nil
return random.FlipCoin(), nil
case "array":
items := propertyMap["items"]
itemsMap := items.(map[string]any)
Expand All @@ -246,10 +247,10 @@ func createArgument(property any, config *Configuration) (any, error) {
if minItems > maxItems {
return nil, fmt.Errorf("minItems (%d) is greater than maxItems(%d)", minItems, maxItems)
}
numberOfElements := RandomInt(minItems, maxItems)
numberOfElements := random.RandomInt(minItems, maxItems)
array := make([]any, numberOfElements)
for i := range numberOfElements {
elem, err := createArgument(itemsMap, config)
elem, err := createArgument(itemsMap, config, random)
if err != nil {
return nil, err
}
Expand All @@ -262,10 +263,10 @@ func createArgument(property any, config *Configuration) (any, error) {
object := make(map[string]interface{})
for fieldName, fieldProperties := range objectProperties {
_, fieldIsRequired := required[fieldName]
if !fieldIsRequired && !RandomBool(config.ObjectToolCallNotRequiredParamProbability) {
if !fieldIsRequired && !random.RandomBool(config.ObjectToolCallNotRequiredParamProbability) {
continue
}
fieldValue, err := createArgument(fieldProperties, config)
fieldValue, err := createArgument(fieldProperties, config, random)
if err != nil {
return nil, err
}
Expand All @@ -277,8 +278,8 @@ func createArgument(property any, config *Configuration) (any, error) {
}
}

func getStringArgument() string {
index := RandomInt(0, len(fakeStringArguments)-1)
func getStringArgument(random *Random) string {
index := random.RandomInt(0, len(fakeStringArguments)-1)
return fakeStringArguments[index]
}

Expand Down
91 changes: 52 additions & 39 deletions pkg/common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,63 +49,66 @@ func ValidateContextWindow(promptTokens int, maxCompletionTokens *int64, maxMode
return isValid, completionTokens, totalTokens
}

func RandomNumericString(length int) string {
digits := "0123456789"
result := make([]byte, length)
for i := 0; i < length; i++ {
num := RandomInt(0, 9)
result[i] = digits[num]
}
return string(result)
// Random is a seeded pseudo-random source. Its methods lock randMutex
// around each use of randomGenerator.
type Random struct {
// randomGenerator produces the pseudo-random values returned by the methods.
randomGenerator *rand.Rand
// randMutex serializes access to randomGenerator.
randMutex sync.Mutex
}

var randomGenerator *rand.Rand
var randMutex sync.Mutex

func InitRandom(seed int64) {
func NewRandom(seed int64) *Random {
src := rand.NewSource(seed)
randomGenerator = rand.New(src)
uuid.SetRand(randomGenerator)
randomGenerator := rand.New(src)
uuid.SetRand(rand.New(rand.NewSource(seed)))
return &Random{randomGenerator: randomGenerator}
}

// Returns an integer between min and max (included)
func RandomInt(min int, max int) int {
randMutex.Lock()
defer randMutex.Unlock()
return randomGenerator.Intn(max-min+1) + min
func (r *Random) RandomInt(min int, max int) int {
r.randMutex.Lock()
defer r.randMutex.Unlock()

return r.randomGenerator.Intn(max-min+1) + min
}

// Returns true or false randomly
func FlipCoin() bool {
return RandomInt(0, 1) != 0
func (r *Random) FlipCoin() bool {
return r.RandomInt(0, 1) != 0
}

// probability is an integer between 0 and 100
func RandomBool(probability int) bool {
randMutex.Lock()
defer randMutex.Unlock()
return randomGenerator.Float64() < float64(probability)/100
func (r *Random) RandomBool(probability int) bool {
r.randMutex.Lock()
defer r.randMutex.Unlock()

return r.randomGenerator.Float64() < float64(probability)/100
}

// Returns a random float64 in the range [min, max)
func RandomFloat(min float64, max float64) float64 {
randMutex.Lock()
defer randMutex.Unlock()
return randomGenerator.Float64()*(max-min) + min
func (r *Random) RandomFloat(min float64, max float64) float64 {
r.randMutex.Lock()
defer r.randMutex.Unlock()

return r.randomGenerator.Float64()*(max-min) + min
}

// Returns a normally distributed int
// If the generated value differs by more than 70% from mean, the returned
// value will be 70% of mean
func RandomNorm(mean int, stddev int) int {
// Returns a normally distributed float64
func (r *Random) RandomNorm(mean int, stddev int) float64 {
if stddev == 0 {
return mean
return float64(mean)
}
randMutex.Lock()
defer randMutex.Unlock()
r.randMutex.Lock()
defer r.randMutex.Unlock()

mean_ := float64(mean)
stddev_ := float64(stddev)
value := randomGenerator.NormFloat64()*stddev_ + mean_
return r.randomGenerator.NormFloat64()*stddev_ + mean_
}

// Returns a normally distributed int, kept within 70% of mean:
// a generated value below 0.3*mean is raised to 0.3*mean, and a value
// above 1.7*mean is capped as well.
func (r *Random) RandomNormTruncated(mean int, stddev int) int {
value := r.RandomNorm(mean, stddev)
mean_ := float64(mean)
if value < 0.3*mean_ {
value = 0.3 * mean_
} else if value > 1.7*mean_ {
Expand All @@ -115,12 +118,22 @@ func RandomNorm(mean int, stddev int) int {
}

// GenerateUUIDString generates a UUID string under a lock
func GenerateUUIDString() string {
randMutex.Lock()
defer randMutex.Unlock()
func (r *Random) GenerateUUIDString() string {
r.randMutex.Lock()
defer r.randMutex.Unlock()
return uuid.NewString()
}

// RandomNumericString returns a string of the requested length in which
// every character is a decimal digit picked independently via RandomInt.
func (r *Random) RandomNumericString(length int) string {
	const decimalDigits = "0123456789"

	buf := make([]byte, length)
	for pos := range buf {
		// One RandomInt call per character, in order.
		buf[pos] = decimalDigits[r.RandomInt(0, 9)]
	}
	return string(buf)
}

// Regular expression for the response tokenization
var re *regexp.Regexp

Expand Down
6 changes: 0 additions & 6 deletions pkg/common/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,11 @@ limitations under the License.
package common

import (
"time"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

var _ = Describe("Utils", Ordered, func() {
BeforeAll(func() {
InitRandom(time.Now().UnixNano())
})

Context("validateContextWindow", func() {
It("should pass when total tokens are within limit", func() {
promptTokens := 100
Expand Down
23 changes: 12 additions & 11 deletions pkg/dataset/custom_dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,23 +435,23 @@ func (d *CustomDataset) GetPromptHashHex(hashBytes []byte) string {
}

// GetTokens returns tokens and finishReason for the given request and mode (echo or random)
func (d *CustomDataset) GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error) {
func (d *CustomDataset) GetTokens(req openaiserverapi.CompletionRequest, mode string, random *common.Random) ([]string, string, error) {
if mode == common.ModeEcho {
return d.echo(req)
}
nTokensToGen, finishReason := howManyTokensToGen(req.ExtractMaxTokens(), req.GetIgnoreEOS())
tokens, err := d.GenerateTokens(req, nTokensToGen, finishReason)
nTokensToGen, finishReason := howManyTokensToGen(req.ExtractMaxTokens(), req.GetIgnoreEOS(), random)
tokens, err := d.GenerateTokens(req, nTokensToGen, finishReason, random)
return tokens, finishReason, err
}

func (d *CustomDataset) query(query string, nTokens int) ([][]string, error) {
func (d *CustomDataset) query(query string, nTokens int, random *common.Random) ([][]string, error) {
rows, err := d.db.Query(query)
if err != nil {
if !d.hasWarned {
d.logger.Error(err, "Failed to query database. Ensure dataset file is still valid. Will generate random tokens instead.")
d.hasWarned = true
}
return [][]string{GenPresetRandomTokens(nTokens)}, nil
return [][]string{GenPresetRandomTokens(random, nTokens)}, nil
}
defer func() {
if cerr := rows.Close(); cerr != nil {
Expand All @@ -461,12 +461,13 @@ func (d *CustomDataset) query(query string, nTokens int) ([][]string, error) {
return unmarshalAllRecords(rows)
}

func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int, finishReason string) ([]string, error) {
func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int, finishReason string,
random *common.Random) ([]string, error) {
// query by prompt hash first
promptHash := d.GetPromptHash(req)
promptHashHex := d.GetPromptHashHex(promptHash)
query := "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + promptHashCol + "=X'" + promptHashHex + "';"
tokensList, err := d.query(query, nTokens)
tokensList, err := d.query(query, nTokens, random)

// filter out results according to finish reason
var filteredTokensList [][]string
Expand All @@ -486,20 +487,20 @@ func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nT
switch finishReason {
case LengthFinishReason:
query = "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + nGenTokensCol + "=" + strconv.Itoa(nTokens) + ";"
tokensList, err = d.query(query, nTokens)
tokensList, err = d.query(query, nTokens, random)
case StopFinishReason:
query = "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + nGenTokensCol + "<=" + strconv.Itoa(nTokens) + ";"
tokensList, err = d.query(query, nTokens)
tokensList, err = d.query(query, nTokens, random)
}
}

if err != nil || len(tokensList) == 0 {
// if both queries fail or return no results, generate random tokens
return GenPresetRandomTokens(nTokens), nil
return GenPresetRandomTokens(random, nTokens), nil
}
if d.hasWarned {
d.hasWarned = false
}
randIndex := common.RandomInt(0, len(tokensList)-1)
randIndex := random.RandomInt(0, len(tokensList)-1)
return tokensList[randIndex], nil
}
9 changes: 5 additions & 4 deletions pkg/dataset/custom_dataset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,11 @@ var _ = Describe("CustomDataset", Ordered, func() {
pathToInvalidTableDB string
pathToInvalidColumnDB string
pathToInvalidTypeDB string
random *common.Random
)

BeforeAll(func() {
common.InitRandom(time.Now().UnixNano())
random = common.NewRandom(time.Now().UnixNano())
})

BeforeEach(func() {
Expand Down Expand Up @@ -182,7 +183,7 @@ var _ = Describe("CustomDataset", Ordered, func() {
req := &openaiserverapi.TextCompletionRequest{
Prompt: testPrompt,
}
tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom)
tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom, random)
Expect(err).NotTo(HaveOccurred())
Expect(finishReason).To(Equal(StopFinishReason))
Expect(tokens).To(Equal([]string{"Hello", " llm-d ", "world", "!"}))
Expand All @@ -196,7 +197,7 @@ var _ = Describe("CustomDataset", Ordered, func() {
Prompt: testPrompt,
MaxTokens: &n,
}
tokens, _, err := dataset.GetTokens(req, common.ModeRandom)
tokens, _, err := dataset.GetTokens(req, common.ModeRandom, random)
Expect(err).NotTo(HaveOccurred())
Expect(len(tokens)).To(BeNumerically("<=", 2))
})
Expand All @@ -208,7 +209,7 @@ var _ = Describe("CustomDataset", Ordered, func() {
req := &openaiserverapi.TextCompletionRequest{
Prompt: testPrompt,
}
tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom)
tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom, random)
Expect(err).NotTo(HaveOccurred())
Expect(finishReason).To(Equal(StopFinishReason))
Expect(tokens).To(Equal([]string{"Hello", " llm-d ", "world", "!"}))
Expand Down
Loading