diff --git a/pkg/common/tools_utils.go b/pkg/common/tools_utils.go
index e18bc924..a0fee0bf 100644
--- a/pkg/common/tools_utils.go
+++ b/pkg/common/tools_utils.go
@@ -92,6 +92,7 @@ func CreateToolCalls(
 	tools []openaiserverapi.Tool,
 	toolChoice openaiserverapi.ToolChoice,
 	config *Configuration,
+	random *Random,
 ) ([]openaiserverapi.ToolCall, int, error) {
 	generateCalls := func(availableTools []openaiserverapi.Tool, minCalls int) ([]openaiserverapi.ToolCall, int, error) {
 		if len(availableTools) == 0 {
@@ -102,7 +103,7 @@ func CreateToolCalls(
 		numberOfCalls := minCalls
 		if len(availableTools) > minCalls {
 			// Randomly decide how many tools to call, between minCalls and the total available.
-			numberOfCalls = RandomInt(minCalls, len(availableTools))
+			numberOfCalls = random.RandomInt(minCalls, len(availableTools))
 		}
 
 		if numberOfCalls == 0 {
@@ -114,11 +115,11 @@ func CreateToolCalls(
 			// Randomly choose which tool to call. We may call the same tool more than once.
 			index := 0
 			if len(availableTools) > 1 {
-				index = RandomInt(0, len(availableTools)-1)
+				index = random.RandomInt(0, len(availableTools)-1)
 			}
 			chosenTool := availableTools[index]
-			args, err := generateToolArguments(chosenTool, config)
+			args, err := generateToolArguments(chosenTool, config, random)
 			if err != nil {
 				return nil, 0, err
 			}
@@ -133,7 +134,7 @@ func CreateToolCalls(
 				TokenizedArguments: Tokenize(string(argsJson)),
 				Name:               &chosenTool.Function.Name,
 			},
-			ID:    "chatcmpl-tool-" + RandomNumericString(10),
+			ID:    "chatcmpl-tool-" + random.RandomNumericString(10),
 			Type:  "function",
 			Index: i,
 		}
@@ -188,7 +189,7 @@ func getRequiredAsMap(property map[string]any) map[string]struct{} {
 	return required
 }
 
-func generateToolArguments(tool openaiserverapi.Tool, config *Configuration) (map[string]any, error) {
+func generateToolArguments(tool openaiserverapi.Tool, config *Configuration, random *Random) (map[string]any, error) {
 	arguments := make(map[string]any)
 	properties, _ := tool.Function.Parameters["properties"].(map[string]any)
 
@@ -196,10 +197,10 @@ func generateToolArguments(tool openaiserverapi.Tool, config *Configuration) (ma
 	for param, property := range properties {
 		_, paramIsRequired := required[param]
-		if !paramIsRequired && !RandomBool(config.ToolCallNotRequiredParamProbability) {
+		if !paramIsRequired && !random.RandomBool(config.ToolCallNotRequiredParamProbability) {
 			continue
 		}
-		arg, err := createArgument(property, config)
+		arg, err := createArgument(property, config, random)
 		if err != nil {
 			return nil, err
 		}
@@ -209,7 +210,7 @@ func generateToolArguments(tool openaiserverapi.Tool, config *Configuration) (ma
 	return arguments, nil
 }
 
-func createArgument(property any, config *Configuration) (any, error) {
+func createArgument(property any, config *Configuration, random *Random) (any, error) {
 	propertyMap, _ := property.(map[string]any)
 	paramType := propertyMap["type"]
 
@@ -218,20 +219,20 @@ func createArgument(property any, config *Configuration) (any, error) {
 	if ok {
 		enumArray, ok := enum.([]any)
 		if ok && len(enumArray) > 0 {
-			index := RandomInt(0, len(enumArray)-1)
+			index := random.RandomInt(0, len(enumArray)-1)
 			return enumArray[index], nil
 		}
 	}
 
 	switch paramType {
 	case "string":
-		return getStringArgument(), nil
+		return getStringArgument(random), nil
 	case "integer":
-		return RandomInt(config.MinToolCallIntegerParam, config.MaxToolCallIntegerParam), nil
+		return random.RandomInt(config.MinToolCallIntegerParam, config.MaxToolCallIntegerParam), nil
 	case "number":
-		return RandomFloat(config.MinToolCallNumberParam, config.MaxToolCallNumberParam), nil
+		return random.RandomFloat(config.MinToolCallNumberParam, config.MaxToolCallNumberParam), nil
 	case "boolean":
-		return FlipCoin(), nil
+		return random.FlipCoin(), nil
 	case "array":
 		items := propertyMap["items"]
 		itemsMap := items.(map[string]any)
@@ -246,10 +247,10 @@ func createArgument(property any, config *Configuration) (any, error) {
 		if minItems > maxItems {
 			return nil, fmt.Errorf("minItems (%d) is greater than maxItems(%d)", minItems, maxItems)
 		}
-		numberOfElements := RandomInt(minItems, maxItems)
+		numberOfElements := random.RandomInt(minItems, maxItems)
 		array := make([]any, numberOfElements)
 		for i := range numberOfElements {
-			elem, err := createArgument(itemsMap, config)
+			elem, err := createArgument(itemsMap, config, random)
 			if err != nil {
 				return nil, err
 			}
@@ -262,10 +263,10 @@ func createArgument(property any, config *Configuration) (any, error) {
 		object := make(map[string]interface{})
 		for fieldName, fieldProperties := range objectProperties {
 			_, fieldIsRequired := required[fieldName]
-			if !fieldIsRequired && !RandomBool(config.ObjectToolCallNotRequiredParamProbability) {
+			if !fieldIsRequired && !random.RandomBool(config.ObjectToolCallNotRequiredParamProbability) {
 				continue
 			}
-			fieldValue, err := createArgument(fieldProperties, config)
+			fieldValue, err := createArgument(fieldProperties, config, random)
 			if err != nil {
 				return nil, err
 			}
@@ -277,8 +278,8 @@ func createArgument(property any, config *Configuration) (any, error) {
 	}
 }
 
-func getStringArgument() string {
-	index := RandomInt(0, len(fakeStringArguments)-1)
+func getStringArgument(random *Random) string {
+	index := random.RandomInt(0, len(fakeStringArguments)-1)
 	return fakeStringArguments[index]
 }
diff --git a/pkg/common/utils.go b/pkg/common/utils.go
index 7050fc55..68d731a4 100644
--- a/pkg/common/utils.go
+++ b/pkg/common/utils.go
@@ -49,63 +49,66 @@ func ValidateContextWindow(promptTokens int, maxCompletionTokens *int64, maxMode
 	return isValid, completionTokens, totalTokens
 }
 
-func RandomNumericString(length int) string {
-	digits := "0123456789"
-	result := make([]byte, length)
-	for i := 0; i < length; i++ {
-		num := RandomInt(0, 9)
-		result[i] = digits[num]
-	}
-	return string(result)
+type Random struct {
+	randomGenerator *rand.Rand
+	randMutex       sync.Mutex
 }
 
-var randomGenerator *rand.Rand
-var randMutex sync.Mutex
-
-func InitRandom(seed int64) {
+func NewRandom(seed int64) *Random {
 	src := rand.NewSource(seed)
-	randomGenerator = rand.New(src)
-	uuid.SetRand(randomGenerator)
+	randomGenerator := rand.New(src)
+	uuid.SetRand(rand.New(rand.NewSource(seed)))
+	return &Random{randomGenerator: randomGenerator}
 }
 
 // Returns an integer between min and max (included)
-func RandomInt(min int, max int) int {
-	randMutex.Lock()
-	defer randMutex.Unlock()
-	return randomGenerator.Intn(max-min+1) + min
+func (r *Random) RandomInt(min int, max int) int {
+	r.randMutex.Lock()
+	defer r.randMutex.Unlock()
+
+	return r.randomGenerator.Intn(max-min+1) + min
 }
 
 // Returns true or false randomly
-func FlipCoin() bool {
-	return RandomInt(0, 1) != 0
+func (r *Random) FlipCoin() bool {
+	return r.RandomInt(0, 1) != 0
 }
 
 // probability is an integer between 0 and 100
-func RandomBool(probability int) bool {
-	randMutex.Lock()
-	defer randMutex.Unlock()
-	return randomGenerator.Float64() < float64(probability)/100
+func (r *Random) RandomBool(probability int) bool {
+	r.randMutex.Lock()
+	defer r.randMutex.Unlock()
+
+	return r.randomGenerator.Float64() < float64(probability)/100
 }
 
 // Returns a random float64 in the range [min, max)
-func RandomFloat(min float64, max float64) float64 {
-	randMutex.Lock()
-	defer randMutex.Unlock()
-	return randomGenerator.Float64()*(max-min) + min
+func (r *Random) RandomFloat(min float64, max float64) float64 {
+	r.randMutex.Lock()
+	defer r.randMutex.Unlock()
+
+	return r.randomGenerator.Float64()*(max-min) + min
 }
 
-// Returns a normally distributed int
-// If the generated value differs by more than 70% from mean, the returned
-// value will be 70% of mean
-func RandomNorm(mean int, stddev int) int {
+// Returns a normally distributed float64
+func (r *Random) RandomNorm(mean int, stddev int) float64 {
 	if stddev == 0 {
-		return mean
+		return float64(mean)
 	}
-	randMutex.Lock()
-	defer randMutex.Unlock()
+	r.randMutex.Lock()
+	defer r.randMutex.Unlock()
+
 	mean_ := float64(mean)
 	stddev_ := float64(stddev)
-	value := randomGenerator.NormFloat64()*stddev_ + mean_
+	return r.randomGenerator.NormFloat64()*stddev_ + mean_
+}
+
+// Returns a normally distributed int
+// If the generated value differs from mean by more than 70%, it is clamped,
+// so the result always stays within [0.3*mean, 1.7*mean]
+func (r *Random) RandomNormTruncated(mean int, stddev int) int {
+	value := r.RandomNorm(mean, stddev)
+	mean_ := float64(mean)
 	if value < 0.3*mean_ {
 		value = 0.3 * mean_
 	} else if value > 1.7*mean_ {
@@ -115,12 +118,22 @@
 // GenerateUUIDString generates a UUID string under a lock
-func GenerateUUIDString() string {
-	randMutex.Lock()
-	defer randMutex.Unlock()
+func (r *Random) GenerateUUIDString() string {
+	r.randMutex.Lock()
+	defer r.randMutex.Unlock()
 	return uuid.NewString()
 }
 
+func (r *Random) RandomNumericString(length int) string {
+	digits := "0123456789"
+	result := make([]byte, length)
+	for i := 0; i < length; i++ {
+		num := r.RandomInt(0, 9)
+		result[i] = digits[num]
+	}
+	return string(result)
+}
+
 // Regular expression for the response tokenization
 var re *regexp.Regexp
diff --git a/pkg/common/utils_test.go b/pkg/common/utils_test.go
index 9a0af043..5db3d070 100644
--- a/pkg/common/utils_test.go
+++ b/pkg/common/utils_test.go
@@ -17,17 +17,11 @@ limitations under the License.
 package common
 
 import (
-	"time"
-
 	. "github.com/onsi/ginkgo/v2"
 	. "github.com/onsi/gomega"
 )
 
 var _ = Describe("Utils", Ordered, func() {
-	BeforeAll(func() {
-		InitRandom(time.Now().UnixNano())
-	})
-
 	Context("validateContextWindow", func() {
 		It("should pass when total tokens are within limit", func() {
 			promptTokens := 100
diff --git a/pkg/dataset/custom_dataset.go b/pkg/dataset/custom_dataset.go
index 69b2f591..34f282eb 100644
--- a/pkg/dataset/custom_dataset.go
+++ b/pkg/dataset/custom_dataset.go
@@ -435,23 +435,23 @@ func (d *CustomDataset) GetPromptHashHex(hashBytes []byte) string {
 }
 
 // GetTokens returns tokens and finishReason for the given request and mode (echo or random)
-func (d *CustomDataset) GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error) {
+func (d *CustomDataset) GetTokens(req openaiserverapi.CompletionRequest, mode string, random *common.Random) ([]string, string, error) {
 	if mode == common.ModeEcho {
 		return d.echo(req)
 	}
-	nTokensToGen, finishReason := howManyTokensToGen(req.ExtractMaxTokens(), req.GetIgnoreEOS())
-	tokens, err := d.GenerateTokens(req, nTokensToGen, finishReason)
+	nTokensToGen, finishReason := howManyTokensToGen(req.ExtractMaxTokens(), req.GetIgnoreEOS(), random)
+	tokens, err := d.GenerateTokens(req, nTokensToGen, finishReason, random)
 	return tokens, finishReason, err
 }
 
-func (d *CustomDataset) query(query string, nTokens int) ([][]string, error) {
+func (d *CustomDataset) query(query string, nTokens int, random *common.Random) ([][]string, error) {
 	rows, err := d.db.Query(query)
 	if err != nil {
 		if !d.hasWarned {
 			d.logger.Error(err, "Failed to query database. Ensure dataset file is still valid. Will generate random tokens instead.")
 			d.hasWarned = true
 		}
-		return [][]string{GenPresetRandomTokens(nTokens)}, nil
+		return [][]string{GenPresetRandomTokens(random, nTokens)}, nil
 	}
 	defer func() {
 		if cerr := rows.Close(); cerr != nil {
@@ -461,12 +461,13 @@ func (d *CustomDataset) query(query string, nTokens int) ([][]string, error) {
 	return unmarshalAllRecords(rows)
 }
 
-func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int, finishReason string) ([]string, error) {
+func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int, finishReason string,
+	random *common.Random) ([]string, error) {
 	// query by prompt hash first
 	promptHash := d.GetPromptHash(req)
 	promptHashHex := d.GetPromptHashHex(promptHash)
 	query := "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + promptHashCol + "=X'" + promptHashHex + "';"
-	tokensList, err := d.query(query, nTokens)
+	tokensList, err := d.query(query, nTokens, random)
 
 	// filter out results according to finish reason
 	var filteredTokensList [][]string
@@ -486,20 +487,20 @@ func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nT
 		switch finishReason {
 		case LengthFinishReason:
 			query = "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + nGenTokensCol + "=" + strconv.Itoa(nTokens) + ";"
-			tokensList, err = d.query(query, nTokens)
+			tokensList, err = d.query(query, nTokens, random)
 		case StopFinishReason:
 			query = "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + nGenTokensCol + "<=" + strconv.Itoa(nTokens) + ";"
-			tokensList, err = d.query(query, nTokens)
+			tokensList, err = d.query(query, nTokens, random)
 		}
 	}
 	if err != nil || len(tokensList) == 0 {
 		// if both queries fail or return no results, generate random tokens
-		return GenPresetRandomTokens(nTokens), nil
+		return GenPresetRandomTokens(random, nTokens), nil
 	}
 	if d.hasWarned {
 		d.hasWarned = false
 	}
 
-	randIndex := common.RandomInt(0, len(tokensList)-1)
+	randIndex := random.RandomInt(0, len(tokensList)-1)
 	return tokensList[randIndex], nil
 }
diff --git a/pkg/dataset/custom_dataset_test.go b/pkg/dataset/custom_dataset_test.go
index afd734a2..f406be36 100644
--- a/pkg/dataset/custom_dataset_test.go
+++ b/pkg/dataset/custom_dataset_test.go
@@ -46,10 +46,11 @@ var _ = Describe("CustomDataset", Ordered, func() {
 		pathToInvalidTableDB  string
 		pathToInvalidColumnDB string
 		pathToInvalidTypeDB   string
+		random                *common.Random
 	)
 
 	BeforeAll(func() {
-		common.InitRandom(time.Now().UnixNano())
+		random = common.NewRandom(time.Now().UnixNano())
 	})
 
 	BeforeEach(func() {
@@ -182,7 +183,7 @@ var _ = Describe("CustomDataset", Ordered, func() {
 			req := &openaiserverapi.TextCompletionRequest{
 				Prompt: testPrompt,
 			}
-			tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom)
+			tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom, random)
 			Expect(err).NotTo(HaveOccurred())
 			Expect(finishReason).To(Equal(StopFinishReason))
 			Expect(tokens).To(Equal([]string{"Hello", " llm-d ", "world", "!"}))
@@ -196,7 +197,7 @@ var _ = Describe("CustomDataset", Ordered, func() {
 				Prompt:    testPrompt,
 				MaxTokens: &n,
 			}
-			tokens, _, err := dataset.GetTokens(req, common.ModeRandom)
+			tokens, _, err := dataset.GetTokens(req, common.ModeRandom, random)
 			Expect(err).NotTo(HaveOccurred())
 			Expect(len(tokens)).To(BeNumerically("<=", 2))
 		})
@@ -208,7 +209,7 @@ var _ = Describe("CustomDataset", Ordered, func() {
 			req := &openaiserverapi.TextCompletionRequest{
 				Prompt: testPrompt,
 			}
-			tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom)
+			tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom, random)
 			Expect(err).NotTo(HaveOccurred())
 			Expect(finishReason).To(Equal(StopFinishReason))
 			Expect(tokens).To(Equal([]string{"Hello", " llm-d ", "world", "!"}))
diff --git a/pkg/dataset/dataset.go b/pkg/dataset/dataset.go
index 5d5d2fe9..15c737ab 100644
--- a/pkg/dataset/dataset.go
+++ b/pkg/dataset/dataset.go
@@ -19,7 +19,6 @@ package dataset
 import (
 	"context"
 	"math"
-	"math/rand"
 
 	"github.com/go-logr/logr"
 	"github.com/llm-d/llm-d-inference-sim/pkg/common"
@@ -74,7 +73,7 @@ type Dataset interface {
 	// Close closes the dataset
 	Close() error
 	// GetTokens returns tokens for the given request and mode (echo or random)
-	GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error)
+	GetTokens(req openaiserverapi.CompletionRequest, mode string, random *common.Random) ([]string, string, error)
 }
 
 func init() {
@@ -89,9 +88,9 @@ func init() {
 
 // GetRandomResponseLen returns int in range [1, responseLenMax]
 // numbers are chosen according a gaussian distribution with mean responseLenMean, and standard deviation responseLenStddev
-func GetRandomResponseLen() int {
+func GetRandomResponseLen(random *common.Random) int {
 	for {
-		val := rand.NormFloat64()*responseLenStddev + responseLenMean
+		val := random.RandomNorm(responseLenMean, responseLenStddev)
 		if val >= 1 && val <= ResponseLenMax {
 			return int(math.Round(val))
 		}
@@ -100,8 +99,8 @@ func GetRandomResponseLen() int {
 }
 
 // GetRandomFinishReason returns finish reason with the probability for 'stop' as defined by stopFinishReasonProbability
-func GetRandomFinishReason() string {
-	if rand.Float64() < stopFinishReasonProbability {
+func GetRandomFinishReason(random *common.Random) string {
+	if random.RandomFloat(0, 1) < stopFinishReasonProbability {
 		return StopFinishReason
 	}
 	return LengthFinishReason
@@ -111,11 +110,11 @@ func GetRandomFinishReason() string {
 // select randomly a sentence from chatCompletionFakeResponses,
 // if number of tokens is lower than required - select another sentence,
 // continue until the required number of tokens is achieved
-func GenPresetRandomTokens(numOfTokens int) []string {
+func GenPresetRandomTokens(random *common.Random, numOfTokens int) []string {
 	allTokens := make([]string, 0)
 
 	for len(allTokens) < numOfTokens {
-		index := common.RandomInt(0, len(chatCompletionFakeResponses)-1)
+		index := random.RandomInt(0, len(chatCompletionFakeResponses)-1)
 		// create tokens from text, splitting by spaces and special characters
 		tokens := common.Tokenize(chatCompletionFakeResponses[index])
 		remaining := numOfTokens - len(allTokens)
@@ -146,13 +145,13 @@ func GenPresetRandomTokens(numOfTokens int) []string {
 // - finish reason is stop
 // if ignore_eos is true - the response will be generated with exactly maxCompletionTokens tokens
 // - request was validated so that when ignore_eos is true, maxCompletionTokens must be defined
-func howManyTokensToGen(maxCompletionTokens *int64, ignore_eos bool) (int, string) {
+func howManyTokensToGen(maxCompletionTokens *int64, ignore_eos bool, random *common.Random) (int, string) {
 	numOfTokens := 0
 	finishReason := StopFinishReason
 
 	// no max completion tokens, return text with random length
 	if maxCompletionTokens == nil {
-		numOfTokens = GetRandomResponseLen()
+		numOfTokens = GetRandomResponseLen(random)
 	} else {
 		maxTokens := int(*maxCompletionTokens)
 		if ignore_eos {
@@ -160,7 +159,7 @@ func howManyTokensToGen(maxCompletionTokens *int64, ignore_eos bool) (int, strin
 			finishReason = LengthFinishReason
 		} else {
 			// max tokens is defined - generate real length of the response based on it
-			numOfTokens = getResponseLengthByHistogram(maxTokens)
+			numOfTokens = getResponseLengthByHistogram(random, maxTokens)
 			if numOfTokens == maxTokens {
 				// if response should be create with maximum number of tokens - finish reason will be 'length'
 				finishReason = LengthFinishReason
@@ -177,17 +176,17 @@ func howManyTokensToGen(maxCompletionTokens *int64, ignore_eos bool) (int, strin
 // The last element of respLenBucketsProbabilities defines the probability of a reposnse with maxToken tokens.
 // Other values define probabilities for the equally sized buckets.
 // If maxToken is small (smaller than number of buckets) - the response length is randomly selected from the range [1, maxTokens]
-func getResponseLengthByHistogram(maxTokens int) int {
+func getResponseLengthByHistogram(random *common.Random, maxTokens int) int {
 	if maxTokens <= 1 {
 		return maxTokens
 	}
 	// maxTokens is small - no need to use the histogram of probabilities, just select a random value in the range [1, maxTokens]
 	if maxTokens <= len(cumulativeBucketsProbabilities) {
-		res := common.RandomInt(1, maxTokens)
+		res := random.RandomInt(1, maxTokens)
 		return res
 	}
 
-	r := common.RandomFloat(0, 1)
+	r := random.RandomFloat(0, 1)
 	// check if r is in the last bucket, then maxTokens should be returned
 	if r > cumulativeBucketsProbabilities[len(cumulativeBucketsProbabilities)-2] {
@@ -208,7 +207,7 @@ func getResponseLengthByHistogram(maxTokens int) int {
 	start, end := calcBucketBoundaries(maxTokens, bucketIndex)
 
 	// pick uniformly within the bucket’s range
-	return common.RandomInt(start, end)
+	return random.RandomInt(start, end)
 }
 
 // calcBucketBoundaries calculates boundaries of a bucket with the given index.
@@ -295,10 +294,11 @@ func (d *BaseDataset) echo(req openaiserverapi.CompletionRequest) ([]string, str
 }
 
 // GetTokens returns tokens and finishReason for the given request and mode (echo or random)
-func (d *BaseDataset) GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error) {
+func (d *BaseDataset) GetTokens(req openaiserverapi.CompletionRequest, mode string,
+	random *common.Random) ([]string, string, error) {
 	if mode == common.ModeEcho {
 		return d.echo(req)
 	}
-	nTokensToGen, finishReason := howManyTokensToGen(req.ExtractMaxTokens(), req.GetIgnoreEOS())
-	return GenPresetRandomTokens(nTokensToGen), finishReason, nil
+	nTokensToGen, finishReason := howManyTokensToGen(req.ExtractMaxTokens(), req.GetIgnoreEOS(), random)
+	return GenPresetRandomTokens(random, nTokensToGen), finishReason, nil
 }
diff --git a/pkg/dataset/dataset_test.go b/pkg/dataset/dataset_test.go
index 1321e9e9..2e978fc5 100644
--- a/pkg/dataset/dataset_test.go
+++ b/pkg/dataset/dataset_test.go
@@ -30,10 +30,11 @@ import (
 var _ = Describe("Dataset", Ordered, func() {
 	var (
 		dataset *BaseDataset
+		random  *common.Random
 	)
 
 	BeforeAll(func() {
-		common.InitRandom(time.Now().UnixNano())
+		random = common.NewRandom(time.Now().UnixNano())
 	})
 
 	BeforeEach(func() {
@@ -44,7 +45,7 @@ var _ = Describe("Dataset", Ordered, func() {
 
 	It("should return complete text", func() {
 		req := &openaiserverapi.ChatCompletionRequest{}
-		tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom)
+		tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom, random)
 		Expect(err).ShouldNot(HaveOccurred())
 		text := strings.Join(tokens, "")
 		Expect(IsValidText(text)).To(BeTrue())
@@ -56,7 +57,7 @@ var _ = Describe("Dataset", Ordered, func() {
 		req := &openaiserverapi.ChatCompletionRequest{
 			MaxCompletionTokens: &maxCompletionTokens,
 		}
-		tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom)
+		tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom, random)
 		Expect(err).ShouldNot(HaveOccurred())
 		tokensCnt := int64(len(tokens))
 		Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens))
@@ -74,7 +75,7 @@ var _ = Describe("Dataset", Ordered, func() {
 		req := &openaiserverapi.ChatCompletionRequest{
 			MaxTokens: &maxCompletionTokens,
 		}
-		tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom)
+		tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom, random)
 		Expect(err).ShouldNot(HaveOccurred())
 		tokensCnt := int64(len(tokens))
 		Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens))
@@ -95,7 +96,7 @@ var _ = Describe("Dataset", Ordered, func() {
 			MaxTokens: &n,
 		}
 		req.SetIgnoreEOS(true)
-		tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom)
+		tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom, random)
 		Expect(err).ShouldNot(HaveOccurred())
 		nGenTokens := int64(len(tokens))
 		Expect(nGenTokens).Should(Equal(n))
@@ -140,7 +141,7 @@ var _ = Describe("Dataset", Ordered, func() {
 	for _, len := range lenArr {
 		name := fmt.Sprintf("should return text with %d tokens", len)
 		It(name, func() {
-			tokens := GenPresetRandomTokens(len)
+			tokens := GenPresetRandomTokens(random, len)
 			Expect(tokens).Should(HaveLen(len))
 		})
 	}
diff --git a/pkg/kv-cache/kv_cache_test.go b/pkg/kv-cache/kv_cache_test.go
index f7e09e00..ebfbc32e 100644
--- a/pkg/kv-cache/kv_cache_test.go
+++ b/pkg/kv-cache/kv_cache_test.go
@@ -121,7 +121,7 @@ type threadTestCase struct {
 }
 
 var _ = Describe("KV cache", Ordered, func() {
-	common.InitRandom(time.Now().UnixNano())
+	random := common.NewRandom(time.Now().UnixNano())
 
 	Context("general tests", func() {
 		// check single request processing, ensure cache is valid after request processing started
@@ -434,7 +434,8 @@ var _ = Describe("KV cache", Ordered, func() {
 
 					for j := range testCase.numOperations {
 						reqID := fmt.Sprintf("req_%d_%d", id, j)
-						blocks := createRandomArray(testCase.minBlockLen, testCase.maxBlockLen, testCase.maxHashValue)
+						blocks := createRandomArray(testCase.minBlockLen, testCase.maxBlockLen,
+							testCase.maxHashValue, random)
 
 						_, err := blockCache.startRequest(reqID, blocks)
 						if err != nil {
@@ -443,7 +444,7 @@ var _ = Describe("KV cache", Ordered, func() {
 							continue
 						}
 
-						time.Sleep(time.Duration(common.RandomInt(1, 100)) * time.Microsecond)
+						time.Sleep(time.Duration(random.RandomInt(1, 100)) * time.Microsecond)
 
 						err = blockCache.finishRequest(reqID)
 						Expect(err).NotTo(HaveOccurred())
@@ -465,16 +466,16 @@ var _ = Describe("KV cache", Ordered, func() {
 	})
 })
 
-func createRandomArray(minArrLen, maxArrLen int, maxValue uint64) []uint64 {
+func createRandomArray(minArrLen, maxArrLen int, maxValue uint64, random *common.Random) []uint64 {
 	// Random length between a and b (inclusive)
-	length := common.RandomInt(minArrLen, maxArrLen)
+	length := random.RandomInt(minArrLen, maxArrLen)
 
 	// Create array with random values
 	arr := make([]uint64, 0)
 	seen := make(map[uint64]struct{})
 
 	for len(arr) < length {
-		val := uint64(common.RandomInt(0, int(maxValue)))
+		val := uint64(random.RandomInt(0, int(maxValue)))
 		if _, exists := seen[val]; !exists {
 			seen[val] = struct{}{}
 			arr = append(arr, val)
diff --git a/pkg/llm-d-inference-sim/failures.go b/pkg/llm-d-inference-sim/failures.go
index 69daf36e..cbabe621 100644
--- a/pkg/llm-d-inference-sim/failures.go
+++ b/pkg/llm-d-inference-sim/failures.go
@@ -44,16 +44,16 @@ var predefinedFailures = map[string]openaiserverapi.CompletionError{
 }
 
 // shouldInjectFailure determines whether to inject a failure based on configuration
-func shouldInjectFailure(config *common.Configuration) bool {
+func shouldInjectFailure(config *common.Configuration, random *common.Random) bool {
 	if config.FailureInjectionRate == 0 {
 		return false
 	}
 
-	return common.RandomInt(1, 100) <= config.FailureInjectionRate
+	return random.RandomInt(1, 100) <= config.FailureInjectionRate
 }
 
 // getRandomFailure returns a random failure from configured types or all types if none specified
-func getRandomFailure(config *common.Configuration) openaiserverapi.CompletionError {
+func getRandomFailure(config *common.Configuration, random *common.Random) openaiserverapi.CompletionError {
 	var availableFailures []string
 	if len(config.FailureTypes) == 0 {
 		// Use all failure types if none specified
@@ -69,7 +69,7 @@ func getRandomFailure(config *common.Configuration) openaiserverapi.CompletionEr
 		return predefinedFailures[common.FailureTypeServerError]
 	}
 
-	randomIndex := common.RandomInt(0, len(availableFailures)-1)
+	randomIndex := random.RandomInt(0, len(availableFailures)-1)
 	randomType := availableFailures[randomIndex]
 
 	// Customize message with current model name
diff --git a/pkg/llm-d-inference-sim/failures_test.go b/pkg/llm-d-inference-sim/failures_test.go
index 1459eed5..2a2756f3 100644
--- a/pkg/llm-d-inference-sim/failures_test.go
+++ b/pkg/llm-d-inference-sim/failures_test.go
@@ -33,8 +33,9 @@ import (
 
 var _ = Describe("Failures", func() {
 	Describe("getRandomFailure", Ordered, func() {
+		var random *common.Random
 		BeforeAll(func() {
-			common.InitRandom(time.Now().UnixNano())
+			random = common.NewRandom(time.Now().UnixNano())
 		})
 
 		It("should return a failure from all types when none specified", func() {
@@ -42,7 +43,7 @@ var _ = Describe("Failures", func() {
 				Model:        "test-model",
 				FailureTypes: []string{},
 			}
-			failure := getRandomFailure(config)
+			failure := getRandomFailure(config, random)
 			Expect(failure.Code).To(BeNumerically(">=", 400))
 			Expect(failure.Message).ToNot(BeEmpty())
 			Expect(failure.Type).ToNot(BeEmpty())
@@ -53,7 +54,7 @@ var _ = Describe("Failures", func() {
 				Model:        "test-model",
 				FailureTypes: []string{common.FailureTypeRateLimit},
 			}
-			failure := getRandomFailure(config)
+			failure := getRandomFailure(config, random)
 			Expect(failure.Code).To(Equal(429))
 			Expect(failure.Type).To(Equal(openaiserverapi.ErrorCodeToType(429)))
 			Expect(strings.Contains(failure.Message, "test-model")).To(BeTrue())
@@ -63,7 +64,7 @@ var _ = Describe("Failures", func() {
 			config := &common.Configuration{
 				FailureTypes: []string{common.FailureTypeInvalidAPIKey},
 			}
-			failure := getRandomFailure(config)
+			failure := getRandomFailure(config, random)
 			Expect(failure.Code).To(Equal(401))
 			Expect(failure.Type).To(Equal(openaiserverapi.ErrorCodeToType(401)))
 			Expect(failure.Message).To(Equal("Incorrect API key provided."))
@@ -73,7 +74,7 @@ var _ = Describe("Failures", func() {
 			config := &common.Configuration{
 				FailureTypes: []string{common.FailureTypeContextLength},
 			}
-			failure := getRandomFailure(config)
+			failure := getRandomFailure(config, random)
 			Expect(failure.Code).To(Equal(400))
 			Expect(failure.Type).To(Equal(openaiserverapi.ErrorCodeToType(400)))
 			Expect(failure.Param).ToNot(BeNil())
@@ -84,7 +85,7 @@ var _ = Describe("Failures", func() {
 			config := &common.Configuration{
 				FailureTypes: []string{common.FailureTypeServerError},
 			}
-			failure := getRandomFailure(config)
+			failure := getRandomFailure(config, random)
 			Expect(failure.Code).To(Equal(503))
 			Expect(failure.Type).To(Equal(openaiserverapi.ErrorCodeToType(503)))
 		})
@@ -94,7 +95,7 @@ var _ = Describe("Failures", func() {
 				Model:        "test-model",
 				FailureTypes: []string{common.FailureTypeModelNotFound},
 			}
-			failure := getRandomFailure(config)
+			failure := getRandomFailure(config, random)
 			Expect(failure.Code).To(Equal(404))
 			Expect(failure.Type).To(Equal(openaiserverapi.ErrorCodeToType(404)))
 			Expect(strings.Contains(failure.Message, "test-model-nonexistent")).To(BeTrue())
@@ -105,7 +106,7 @@ var _ = Describe("Failures", func() {
 				FailureTypes: []string{},
 			}
 			// This test is probabilistic since it randomly selects, but we can test structure
-			failure := getRandomFailure(config)
+			failure := getRandomFailure(config, random)
 			Expect(failure.Code).To(BeNumerically(">=", 400))
 			Expect(failure.Type).ToNot(BeEmpty())
 		})
diff --git a/pkg/llm-d-inference-sim/latencies.go b/pkg/llm-d-inference-sim/latencies.go
index b98b61e5..1a1e4981 100644
--- a/pkg/llm-d-inference-sim/latencies.go
+++ b/pkg/llm-d-inference-sim/latencies.go
@@ -17,8 +17,6 @@ limitations under the License.
 // Package vllmsim implements the vLLM simulator.
 package llmdinferencesim
 
-import "github.com/llm-d/llm-d-inference-sim/pkg/common"
-
 func (s *VllmSimulator) getCurrLoadFactor() float64 {
 	if s.config.MaxNumSeqs <= 1 {
 		return 1.0
@@ -44,22 +42,22 @@ func (s *VllmSimulator) getWaitTimeToFirstToken(nPromptTokens int, nCachedPrompt
 		if s.config.KVCacheTransferLatency == 0 && s.config.KVCacheTransferLatencyStdDev == 0 {
 			// is disaggregated PD and ttft is calculated using number of prompt tokens
 			kvCacheTransT := s.config.KVCacheTransferTimePerToken * nPromptTokens
-			return common.RandomNorm(kvCacheTransT, s.config.KVCacheTransferTimeStdDev)
+			return s.random.RandomNormTruncated(kvCacheTransT, s.config.KVCacheTransferTimeStdDev)
 		}
 		// is disaggregated PD and *not* using number of prompt tokens
-		return common.RandomNorm(s.config.KVCacheTransferLatency, s.config.KVCacheTransferLatencyStdDev)
+		return s.random.RandomNormTruncated(s.config.KVCacheTransferLatency, s.config.KVCacheTransferLatencyStdDev)
 	}
 
 	if s.config.TimeToFirstToken == 0 && s.config.TimeToFirstTokenStdDev == 0 {
 		// is aggregated PD and ttft is calculated using number of prompt tokens that are not in kv cache
 		prefillTime := s.getPrefillOverhead() + (nPromptTokens-nCachedPromptTokens)*s.getPrefillTimePerToken()
-		return common.RandomNorm(prefillTime, s.config.PrefillTimeStdDev)
+		return s.random.RandomNormTruncated(prefillTime, s.config.PrefillTimeStdDev)
 	}
 	// is aggregated PD and *not* using number of prompt tokens
-	return common.RandomNorm(s.getTimeToFirstToken(), s.config.TimeToFirstTokenStdDev)
+	return s.random.RandomNormTruncated(s.getTimeToFirstToken(), s.config.TimeToFirstTokenStdDev)
 }
 
 // returns inter token latency
 func (s *VllmSimulator) getInterTokenLatency() int {
 	latency := int(float64(s.config.InterTokenLatency) * s.getCurrLoadFactor())
-	return common.RandomNorm(latency, s.config.InterTokenLatencyStdDev)
+	return s.random.RandomNormTruncated(latency, s.config.InterTokenLatencyStdDev)
 }
diff --git a/pkg/llm-d-inference-sim/latencies_test.go b/pkg/llm-d-inference-sim/latencies_test.go
index 8c20450e..97435763 100644
--- a/pkg/llm-d-inference-sim/latencies_test.go
+++ b/pkg/llm-d-inference-sim/latencies_test.go
@@ -43,7 +43,7 @@ var _ = Describe("Check random latencies", Ordered, func() {
 
 		simulator.metrics.runReqChan = make(chan int64, 100)
 
-		common.InitRandom(time.Now().UnixNano())
+		simulator.random = common.NewRandom(time.Now().UnixNano())
 	})
 
 	DescribeTable("should calculate inter token latency correctly",
diff --git a/pkg/llm-d-inference-sim/server.go b/pkg/llm-d-inference-sim/server.go
index 8247ce33..e21717a9 100644
--- a/pkg/llm-d-inference-sim/server.go
+++ b/pkg/llm-d-inference-sim/server.go
@@ -107,7 +107,7 @@ func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener)
 
 // readRequest reads and parses data from the body of the given request according the type defined by isChatCompletion
 func (s *VllmSimulator) readRequest(ctx *fasthttp.RequestCtx, isChatCompletion bool) (openaiserverapi.CompletionRequest, error) {
-	requestID := common.GenerateUUIDString()
+	requestID := s.random.GenerateUUIDString()
 	if isChatCompletion {
 		var req openaiserverapi.ChatCompletionRequest
diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go
index 7f9bc249..89bec6c7 100644
--- a/pkg/llm-d-inference-sim/simulator.go
+++ b/pkg/llm-d-inference-sim/simulator.go
@@ -184,6 +184,8 @@ type VllmSimulator struct {
 	metrics metricsData
 	// loras contains information about which LoRAs are in use
 	loras *lorasUsageInfo
+	// random is a generator with a configurable seed, used to produce reproducible random responses
+	random *common.Random
 
 	// a channel for free workers
 	freeWorkers chan *worker
@@ -279,7 +281,7 @@ func (s *VllmSimulator) startSim(ctx context.Context) error {
 }
 
 func (s *VllmSimulator) initializeSim(ctx context.Context) error {
-	common.InitRandom(s.config.Seed)
+	s.random = common.NewRandom(s.config.Seed)
 
 	for _, lora := range s.config.LoraModules {
 		s.loraAdaptors.Store(lora.Name, "")
@@ -496,8 +498,8 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple
 	}()
 
 	// Check if we should inject a failure
-	if shouldInjectFailure(s.config) {
-		failure := getRandomFailure(s.config)
+	if shouldInjectFailure(s.config, s.random) {
+		failure := getRandomFailure(s.config, s.random)
 		s.sendCompletionError(ctx, failure, true)
 		return
 	}
@@ -557,7 +559,7 @@ func (s *VllmSimulator) responseSentCallback(model string, isChatCompletion bool
 // from --served-model-name (for a base-model request) or the LoRA adapter name (for a LoRA request).
 func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respTokens []string, toolCalls []openaiserverapi.ToolCall,
 	finishReason *string, usageData *openaiserverapi.Usage, modelName string, doRemoteDecode bool) openaiserverapi.CompletionResponse {
-	baseResp := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+common.GenerateUUIDString(),
+	baseResp := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(),
 		time.Now().Unix(), modelName, usageData)
 
 	if doRemoteDecode {
diff --git a/pkg/llm-d-inference-sim/streaming.go b/pkg/llm-d-inference-sim/streaming.go
index 6be9a43e..7f220043 100644
--- a/pkg/llm-d-inference-sim/streaming.go
+++ b/pkg/llm-d-inference-sim/streaming.go
@@ -171,7 +171,7 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ
 // createUsageChunk creates and returns a CompletionRespChunk with usage data, a single chunk of streamed completion API response,
 // supports both modes (text and chat)
 func (s *VllmSimulator) createUsageChunk(context *streamingContext, usageData *openaiserverapi.Usage) openaiserverapi.CompletionRespChunk {
-	baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+common.GenerateUUIDString(),
+	baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(),
 		context.creationTime, context.model, usageData)
 
 	if context.isChatCompletion {
@@ -186,7 +186,7 @@ func (s *VllmSimulator) createUsageChunk(context *streamingContext, usageData *o
 // createTextCompletionChunk creates and returns a CompletionRespChunk, a single chunk of streamed completion API response,
 // for text completion
 func (s *VllmSimulator) createTextCompletionChunk(context *streamingContext, token string, finishReason *string) openaiserverapi.CompletionRespChunk {
-	baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+common.GenerateUUIDString(),
+	baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(),
 		context.creationTime, context.model, nil)
 	baseChunk.Object = textCompletionObject
 	return openaiserverapi.CreateTextCompletionResponse(baseChunk,
@@ -198,7 +198,7 @@ func (s *VllmSimulator) createTextCompletionChunk(context *streamingContext, tok
 // API response, for chat completion. It sets either role, or token, or tool call info in the message.
 func (s *VllmSimulator) createChatCompletionChunk(context *streamingContext, token string, tool *openaiserverapi.ToolCall,
 	role string, finishReason *string) openaiserverapi.CompletionRespChunk {
-	baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+common.GenerateUUIDString(),
+	baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(),
 		context.creationTime, context.model, nil)
 	baseChunk.Object = chatCompletionChunkObject
 	chunk := openaiserverapi.CreateChatCompletionRespChunk(baseChunk,
diff --git a/pkg/llm-d-inference-sim/worker.go b/pkg/llm-d-inference-sim/worker.go
index 674c283f..c1fe3719 100644
--- a/pkg/llm-d-inference-sim/worker.go
+++ b/pkg/llm-d-inference-sim/worker.go
@@ -97,13 +97,13 @@ func (s *VllmSimulator) processRequest(reqCtx *openaiserverapi.CompletionReqCtx)
 		!common.IsToolChoiceNone(req.GetToolChoice()) && req.GetTools() != nil {
 		toolCalls, completionTokens, err =
-			common.CreateToolCalls(req.GetTools(), req.GetToolChoice(), s.config)
+			common.CreateToolCalls(req.GetTools(), req.GetToolChoice(), s.config, s.random)
 		finishReason = dataset.ToolsFinishReason
 	}
 	if toolCalls == nil && err == nil {
 		// Either no tool calls were defined, or we randomly chose not to create tool calls,
 		// so we generate a response text.
-		responseTokens, finishReason, err = s.dataset.GetTokens(req, s.config.Mode)
+		responseTokens, finishReason, err = s.dataset.GetTokens(req, s.config.Mode, s.random)
 		completionTokens += len(responseTokens)
 	}
 	if err != nil {
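
Usage sketch of the new seed-injected API (illustrative only; the fixed seed and the specific calls below are arbitrary examples, not taken from the diff):

	// One generator is constructed from the configured seed and then passed
	// around explicitly, replacing the removed package-level helpers.
	random := common.NewRandom(42)                   // same seed => same sequence
	fmt.Println(random.RandomInt(1, 10))             // int in [1, 10]
	fmt.Println(random.RandomFloat(0, 1))            // float64 in [0, 1)
	fmt.Println(random.RandomNumericString(10))      // 10 random digits
	fmt.Println(random.RandomNormTruncated(100, 30)) // clamped to [30, 170]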