Skip to content

Commit 5ca355b

Browse files
authored
Changed random from static to a field in the simulator (#238)
Signed-off-by: irar2 <[email protected]>
1 parent de71f5d commit 5ca355b

File tree

16 files changed

+152
-139
lines changed

16 files changed

+152
-139
lines changed

pkg/common/tools_utils.go

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ func CreateToolCalls(
9292
tools []openaiserverapi.Tool,
9393
toolChoice openaiserverapi.ToolChoice,
9494
config *Configuration,
95+
random *Random,
9596
) ([]openaiserverapi.ToolCall, int, error) {
9697
generateCalls := func(availableTools []openaiserverapi.Tool, minCalls int) ([]openaiserverapi.ToolCall, int, error) {
9798
if len(availableTools) == 0 {
@@ -102,7 +103,7 @@ func CreateToolCalls(
102103
numberOfCalls := minCalls
103104
if len(availableTools) > minCalls {
104105
// Randomly decide how many tools to call, between minCalls and the total available.
105-
numberOfCalls = RandomInt(minCalls, len(availableTools))
106+
numberOfCalls = random.RandomInt(minCalls, len(availableTools))
106107
}
107108

108109
if numberOfCalls == 0 {
@@ -114,11 +115,11 @@ func CreateToolCalls(
114115
// Randomly choose which tool to call. We may call the same tool more than once.
115116
index := 0
116117
if len(availableTools) > 1 {
117-
index = RandomInt(0, len(availableTools)-1)
118+
index = random.RandomInt(0, len(availableTools)-1)
118119
}
119120
chosenTool := availableTools[index]
120121

121-
args, err := generateToolArguments(chosenTool, config)
122+
args, err := generateToolArguments(chosenTool, config, random)
122123
if err != nil {
123124
return nil, 0, err
124125
}
@@ -133,7 +134,7 @@ func CreateToolCalls(
133134
TokenizedArguments: Tokenize(string(argsJson)),
134135
Name: &chosenTool.Function.Name,
135136
},
136-
ID: "chatcmpl-tool-" + RandomNumericString(10),
137+
ID: "chatcmpl-tool-" + random.RandomNumericString(10),
137138
Type: "function",
138139
Index: i,
139140
}
@@ -188,18 +189,18 @@ func getRequiredAsMap(property map[string]any) map[string]struct{} {
188189
return required
189190
}
190191

191-
func generateToolArguments(tool openaiserverapi.Tool, config *Configuration) (map[string]any, error) {
192+
func generateToolArguments(tool openaiserverapi.Tool, config *Configuration, random *Random) (map[string]any, error) {
192193
arguments := make(map[string]any)
193194
properties, _ := tool.Function.Parameters["properties"].(map[string]any)
194195

195196
required := getRequiredAsMap(tool.Function.Parameters)
196197

197198
for param, property := range properties {
198199
_, paramIsRequired := required[param]
199-
if !paramIsRequired && !RandomBool(config.ToolCallNotRequiredParamProbability) {
200+
if !paramIsRequired && !random.RandomBool(config.ToolCallNotRequiredParamProbability) {
200201
continue
201202
}
202-
arg, err := createArgument(property, config)
203+
arg, err := createArgument(property, config, random)
203204
if err != nil {
204205
return nil, err
205206
}
@@ -209,7 +210,7 @@ func generateToolArguments(tool openaiserverapi.Tool, config *Configuration) (ma
209210
return arguments, nil
210211
}
211212

212-
func createArgument(property any, config *Configuration) (any, error) {
213+
func createArgument(property any, config *Configuration, random *Random) (any, error) {
213214
propertyMap, _ := property.(map[string]any)
214215
paramType := propertyMap["type"]
215216

@@ -218,20 +219,20 @@ func createArgument(property any, config *Configuration) (any, error) {
218219
if ok {
219220
enumArray, ok := enum.([]any)
220221
if ok && len(enumArray) > 0 {
221-
index := RandomInt(0, len(enumArray)-1)
222+
index := random.RandomInt(0, len(enumArray)-1)
222223
return enumArray[index], nil
223224
}
224225
}
225226

226227
switch paramType {
227228
case "string":
228-
return getStringArgument(), nil
229+
return getStringArgument(random), nil
229230
case "integer":
230-
return RandomInt(config.MinToolCallIntegerParam, config.MaxToolCallIntegerParam), nil
231+
return random.RandomInt(config.MinToolCallIntegerParam, config.MaxToolCallIntegerParam), nil
231232
case "number":
232-
return RandomFloat(config.MinToolCallNumberParam, config.MaxToolCallNumberParam), nil
233+
return random.RandomFloat(config.MinToolCallNumberParam, config.MaxToolCallNumberParam), nil
233234
case "boolean":
234-
return FlipCoin(), nil
235+
return random.FlipCoin(), nil
235236
case "array":
236237
items := propertyMap["items"]
237238
itemsMap := items.(map[string]any)
@@ -246,10 +247,10 @@ func createArgument(property any, config *Configuration) (any, error) {
246247
if minItems > maxItems {
247248
return nil, fmt.Errorf("minItems (%d) is greater than maxItems(%d)", minItems, maxItems)
248249
}
249-
numberOfElements := RandomInt(minItems, maxItems)
250+
numberOfElements := random.RandomInt(minItems, maxItems)
250251
array := make([]any, numberOfElements)
251252
for i := range numberOfElements {
252-
elem, err := createArgument(itemsMap, config)
253+
elem, err := createArgument(itemsMap, config, random)
253254
if err != nil {
254255
return nil, err
255256
}
@@ -262,10 +263,10 @@ func createArgument(property any, config *Configuration) (any, error) {
262263
object := make(map[string]interface{})
263264
for fieldName, fieldProperties := range objectProperties {
264265
_, fieldIsRequired := required[fieldName]
265-
if !fieldIsRequired && !RandomBool(config.ObjectToolCallNotRequiredParamProbability) {
266+
if !fieldIsRequired && !random.RandomBool(config.ObjectToolCallNotRequiredParamProbability) {
266267
continue
267268
}
268-
fieldValue, err := createArgument(fieldProperties, config)
269+
fieldValue, err := createArgument(fieldProperties, config, random)
269270
if err != nil {
270271
return nil, err
271272
}
@@ -277,8 +278,8 @@ func createArgument(property any, config *Configuration) (any, error) {
277278
}
278279
}
279280

280-
func getStringArgument() string {
281-
index := RandomInt(0, len(fakeStringArguments)-1)
281+
func getStringArgument(random *Random) string {
282+
index := random.RandomInt(0, len(fakeStringArguments)-1)
282283
return fakeStringArguments[index]
283284
}
284285

pkg/common/utils.go

Lines changed: 52 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -49,63 +49,66 @@ func ValidateContextWindow(promptTokens int, maxCompletionTokens *int64, maxMode
4949
return isValid, completionTokens, totalTokens
5050
}
5151

52-
func RandomNumericString(length int) string {
53-
digits := "0123456789"
54-
result := make([]byte, length)
55-
for i := 0; i < length; i++ {
56-
num := RandomInt(0, 9)
57-
result[i] = digits[num]
58-
}
59-
return string(result)
52+
type Random struct {
53+
randomGenerator *rand.Rand
54+
randMutex sync.Mutex
6055
}
6156

62-
var randomGenerator *rand.Rand
63-
var randMutex sync.Mutex
64-
65-
func InitRandom(seed int64) {
57+
func NewRandom(seed int64) *Random {
6658
src := rand.NewSource(seed)
67-
randomGenerator = rand.New(src)
68-
uuid.SetRand(randomGenerator)
59+
randomGenerator := rand.New(src)
60+
uuid.SetRand(rand.New(rand.NewSource(seed)))
61+
return &Random{randomGenerator: randomGenerator}
6962
}
7063

7164
// Returns an integer between min and max (included)
72-
func RandomInt(min int, max int) int {
73-
randMutex.Lock()
74-
defer randMutex.Unlock()
75-
return randomGenerator.Intn(max-min+1) + min
65+
func (r *Random) RandomInt(min int, max int) int {
66+
r.randMutex.Lock()
67+
defer r.randMutex.Unlock()
68+
69+
return r.randomGenerator.Intn(max-min+1) + min
7670
}
7771

7872
// Returns true or false randomly
79-
func FlipCoin() bool {
80-
return RandomInt(0, 1) != 0
73+
func (r *Random) FlipCoin() bool {
74+
return r.RandomInt(0, 1) != 0
8175
}
8276

8377
// probability is an integer between 0 and 100
84-
func RandomBool(probability int) bool {
85-
randMutex.Lock()
86-
defer randMutex.Unlock()
87-
return randomGenerator.Float64() < float64(probability)/100
78+
func (r *Random) RandomBool(probability int) bool {
79+
r.randMutex.Lock()
80+
defer r.randMutex.Unlock()
81+
82+
return r.randomGenerator.Float64() < float64(probability)/100
8883
}
8984

9085
// Returns a random float64 in the range [min, max)
91-
func RandomFloat(min float64, max float64) float64 {
92-
randMutex.Lock()
93-
defer randMutex.Unlock()
94-
return randomGenerator.Float64()*(max-min) + min
86+
func (r *Random) RandomFloat(min float64, max float64) float64 {
87+
r.randMutex.Lock()
88+
defer r.randMutex.Unlock()
89+
90+
return r.randomGenerator.Float64()*(max-min) + min
9591
}
9692

97-
// Returns a normally distributed int
98-
// If the generated value differs by more than 70% from mean, the returned
99-
// value will be 70% of mean
100-
func RandomNorm(mean int, stddev int) int {
93+
// Returns a normally distributed float64
94+
func (r *Random) RandomNorm(mean int, stddev int) float64 {
10195
if stddev == 0 {
102-
return mean
96+
return float64(mean)
10397
}
104-
randMutex.Lock()
105-
defer randMutex.Unlock()
98+
r.randMutex.Lock()
99+
defer r.randMutex.Unlock()
100+
106101
mean_ := float64(mean)
107102
stddev_ := float64(stddev)
108-
value := randomGenerator.NormFloat64()*stddev_ + mean_
103+
return r.randomGenerator.NormFloat64()*stddev_ + mean_
104+
}
105+
106+
// Returns a normally distributed int
107+
// If the generated value differs by more than 70% from mean, the returned
108+
// value will be 70% of mean
109+
func (r *Random) RandomNormTruncated(mean int, stddev int) int {
110+
value := r.RandomNorm(mean, stddev)
111+
mean_ := float64(mean)
109112
if value < 0.3*mean_ {
110113
value = 0.3 * mean_
111114
} else if value > 1.7*mean_ {
@@ -115,12 +118,22 @@ func RandomNorm(mean int, stddev int) int {
115118
}
116119

117120
// GenerateUUIDString generates a UUID string under a lock
118-
func GenerateUUIDString() string {
119-
randMutex.Lock()
120-
defer randMutex.Unlock()
121+
func (r *Random) GenerateUUIDString() string {
122+
r.randMutex.Lock()
123+
defer r.randMutex.Unlock()
121124
return uuid.NewString()
122125
}
123126

127+
func (r *Random) RandomNumericString(length int) string {
128+
digits := "0123456789"
129+
result := make([]byte, length)
130+
for i := 0; i < length; i++ {
131+
num := r.RandomInt(0, 9)
132+
result[i] = digits[num]
133+
}
134+
return string(result)
135+
}
136+
124137
// Regular expression for the response tokenization
125138
var re *regexp.Regexp
126139

pkg/common/utils_test.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,11 @@ limitations under the License.
1717
package common
1818

1919
import (
20-
"time"
21-
2220
. "github.com/onsi/ginkgo/v2"
2321
. "github.com/onsi/gomega"
2422
)
2523

2624
var _ = Describe("Utils", Ordered, func() {
27-
BeforeAll(func() {
28-
InitRandom(time.Now().UnixNano())
29-
})
30-
3125
Context("validateContextWindow", func() {
3226
It("should pass when total tokens are within limit", func() {
3327
promptTokens := 100

pkg/dataset/custom_dataset.go

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -435,23 +435,23 @@ func (d *CustomDataset) GetPromptHashHex(hashBytes []byte) string {
435435
}
436436

437437
// GetTokens returns tokens and finishReason for the given request and mode (echo or random)
438-
func (d *CustomDataset) GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error) {
438+
func (d *CustomDataset) GetTokens(req openaiserverapi.CompletionRequest, mode string, random *common.Random) ([]string, string, error) {
439439
if mode == common.ModeEcho {
440440
return d.echo(req)
441441
}
442-
nTokensToGen, finishReason := howManyTokensToGen(req.ExtractMaxTokens(), req.GetIgnoreEOS())
443-
tokens, err := d.GenerateTokens(req, nTokensToGen, finishReason)
442+
nTokensToGen, finishReason := howManyTokensToGen(req.ExtractMaxTokens(), req.GetIgnoreEOS(), random)
443+
tokens, err := d.GenerateTokens(req, nTokensToGen, finishReason, random)
444444
return tokens, finishReason, err
445445
}
446446

447-
func (d *CustomDataset) query(query string, nTokens int) ([][]string, error) {
447+
func (d *CustomDataset) query(query string, nTokens int, random *common.Random) ([][]string, error) {
448448
rows, err := d.db.Query(query)
449449
if err != nil {
450450
if !d.hasWarned {
451451
d.logger.Error(err, "Failed to query database. Ensure dataset file is still valid. Will generate random tokens instead.")
452452
d.hasWarned = true
453453
}
454-
return [][]string{GenPresetRandomTokens(nTokens)}, nil
454+
return [][]string{GenPresetRandomTokens(random, nTokens)}, nil
455455
}
456456
defer func() {
457457
if cerr := rows.Close(); cerr != nil {
@@ -461,12 +461,13 @@ func (d *CustomDataset) query(query string, nTokens int) ([][]string, error) {
461461
return unmarshalAllRecords(rows)
462462
}
463463

464-
func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int, finishReason string) ([]string, error) {
464+
func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int, finishReason string,
465+
random *common.Random) ([]string, error) {
465466
// query by prompt hash first
466467
promptHash := d.GetPromptHash(req)
467468
promptHashHex := d.GetPromptHashHex(promptHash)
468469
query := "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + promptHashCol + "=X'" + promptHashHex + "';"
469-
tokensList, err := d.query(query, nTokens)
470+
tokensList, err := d.query(query, nTokens, random)
470471

471472
// filter out results according to finish reason
472473
var filteredTokensList [][]string
@@ -486,20 +487,20 @@ func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nT
486487
switch finishReason {
487488
case LengthFinishReason:
488489
query = "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + nGenTokensCol + "=" + strconv.Itoa(nTokens) + ";"
489-
tokensList, err = d.query(query, nTokens)
490+
tokensList, err = d.query(query, nTokens, random)
490491
case StopFinishReason:
491492
query = "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + nGenTokensCol + "<=" + strconv.Itoa(nTokens) + ";"
492-
tokensList, err = d.query(query, nTokens)
493+
tokensList, err = d.query(query, nTokens, random)
493494
}
494495
}
495496

496497
if err != nil || len(tokensList) == 0 {
497498
// if both queries fail or return no results, generate random tokens
498-
return GenPresetRandomTokens(nTokens), nil
499+
return GenPresetRandomTokens(random, nTokens), nil
499500
}
500501
if d.hasWarned {
501502
d.hasWarned = false
502503
}
503-
randIndex := common.RandomInt(0, len(tokensList)-1)
504+
randIndex := random.RandomInt(0, len(tokensList)-1)
504505
return tokensList[randIndex], nil
505506
}

pkg/dataset/custom_dataset_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,11 @@ var _ = Describe("CustomDataset", Ordered, func() {
4646
pathToInvalidTableDB string
4747
pathToInvalidColumnDB string
4848
pathToInvalidTypeDB string
49+
random *common.Random
4950
)
5051

5152
BeforeAll(func() {
52-
common.InitRandom(time.Now().UnixNano())
53+
random = common.NewRandom(time.Now().UnixNano())
5354
})
5455

5556
BeforeEach(func() {
@@ -182,7 +183,7 @@ var _ = Describe("CustomDataset", Ordered, func() {
182183
req := &openaiserverapi.TextCompletionRequest{
183184
Prompt: testPrompt,
184185
}
185-
tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom)
186+
tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom, random)
186187
Expect(err).NotTo(HaveOccurred())
187188
Expect(finishReason).To(Equal(StopFinishReason))
188189
Expect(tokens).To(Equal([]string{"Hello", " llm-d ", "world", "!"}))
@@ -196,7 +197,7 @@ var _ = Describe("CustomDataset", Ordered, func() {
196197
Prompt: testPrompt,
197198
MaxTokens: &n,
198199
}
199-
tokens, _, err := dataset.GetTokens(req, common.ModeRandom)
200+
tokens, _, err := dataset.GetTokens(req, common.ModeRandom, random)
200201
Expect(err).NotTo(HaveOccurred())
201202
Expect(len(tokens)).To(BeNumerically("<=", 2))
202203
})
@@ -208,7 +209,7 @@ var _ = Describe("CustomDataset", Ordered, func() {
208209
req := &openaiserverapi.TextCompletionRequest{
209210
Prompt: testPrompt,
210211
}
211-
tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom)
212+
tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom, random)
212213
Expect(err).NotTo(HaveOccurred())
213214
Expect(finishReason).To(Equal(StopFinishReason))
214215
Expect(tokens).To(Equal([]string{"Hello", " llm-d ", "world", "!"}))

0 commit comments

Comments
 (0)