diff --git a/pkg/common/utils.go b/pkg/common/utils.go index 2cb4ad66..d3ea5b44 100644 --- a/pkg/common/utils.go +++ b/pkg/common/utils.go @@ -39,6 +39,10 @@ const ( RemoteDecodeFinishReason = "remote_decode" ) +// this array defines the probabilities for the buckets to be used for the generation of number of tokens in response +var respLenBucketsProbabilities = [...]float64{0.2, 0.3, 0.2, 0.05, 0.1, 0.15} +var cumulativeBucketsProbabilities []float64 + // list of responses to use in random mode for comepltion requests var chatCompletionFakeResponses = []string{ `Testing@, #testing 1$ ,2%,3^, [4&*5], 6~, 7-_ + (8 : 9) / \ < > .`, @@ -54,6 +58,16 @@ var chatCompletionFakeResponses = []string{ `Give a man a fish and you feed him for a day; teach a man to fish and you feed him for a lifetime`, } +func init() { + cumulativeBucketsProbabilities = make([]float64, len(respLenBucketsProbabilities)) + sum := 0.0 + + for i, val := range respLenBucketsProbabilities { + sum += val + cumulativeBucketsProbabilities[i] = sum + } +} + // returns the max tokens or error if incorrect func GetMaxTokens(maxCompletionTokens *int64, maxTokens *int64) (*int64, error) { var typeToken string @@ -154,14 +168,67 @@ func GetRandomResponseText(maxCompletionTokens *int64) (string, string) { if maxCompletionTokens == nil { numOfTokens = GetRandomResponseLen() } else { - numOfTokens = int(*maxCompletionTokens) - finishReason = GetRandomFinishReason() + maxTokens := int(*maxCompletionTokens) + // max tokens is defined - generate real length of the response based on it + numOfTokens = getResponseLengthByHistogram(maxTokens) + if numOfTokens == maxTokens { + // if the response should be created with the maximum number of tokens - finish reason will be 'length' + finishReason = LengthFinishReason + } } text := GetRandomText(numOfTokens) return text, finishReason } +// getResponseLengthByHistogram calculates the number of tokens to be returned in a response based on the max tokens value and the pre-defined 
buckets. +// The response length is distributed according to the probabilities, defined in respLenBucketsProbabilities. +// The histogram contains equally sized buckets and the last special bucket, which contains only the maxTokens value. +// The last element of respLenBucketsProbabilities defines the probability of a response with maxTokens tokens. +// Other values define probabilities for the equally sized buckets. +// If maxTokens is small (not larger than the number of buckets) - the response length is randomly selected from the range [1, maxTokens] +func getResponseLengthByHistogram(maxTokens int) int { + if maxTokens <= 1 { + return maxTokens + } + // maxTokens is small - no need to use the histogram of probabilities, just select a random value in the range [1, maxTokens] + if maxTokens <= len(cumulativeBucketsProbabilities) { + res := RandomInt(1, maxTokens) + return res + } + + r := RandomFloat(0, 1) + + // check if r is in the last bucket, then maxTokens should be returned + if r > cumulativeBucketsProbabilities[len(cumulativeBucketsProbabilities)-2] { + return maxTokens + } + + // determine which bucket to use, the bucket with a cumulative probability larger than r is the bucket to use + // initialize bucketIndex with the last bucket to handle the case (which should not happen) when the probabilities sum is less than 1 + bucketIndex := len(cumulativeBucketsProbabilities) - 1 + for i, c := range cumulativeBucketsProbabilities { + if r <= c { + bucketIndex = i + break + } + } + + // calculate the size of all of the buckets (except the special last bucket) + bucketSize := float64(maxTokens-1) / float64(len(cumulativeBucketsProbabilities)-1) + // start is the minimum number in the required bucket + start := int(bucketSize*float64(bucketIndex)) + 1 + // end is the maximum number in the required bucket + end := int(bucketSize * float64(bucketIndex+1)) + // sometimes end could be maxTokens because of rounding, change the value to maxTokens-1 + if end >= maxTokens { + end 
= maxTokens - 1 + } + + // pick uniformly within the bucket’s range + return RandomInt(start, end) +} + // GetResponseText returns response text, from a given text // considering max completion tokens if it is not nil, and a finish reason (stop or length) func GetResponseText(maxCompletionTokens *int64, text string) (string, string) { diff --git a/pkg/common/utils_test.go b/pkg/common/utils_test.go index dd6cadab..b8f3285e 100644 --- a/pkg/common/utils_test.go +++ b/pkg/common/utils_test.go @@ -38,16 +38,28 @@ var _ = Describe("Utils", Ordered, func() { It("should return short text", func() { maxCompletionTokens := int64(2) text, finishReason := GetRandomResponseText(&maxCompletionTokens) - Expect(int64(len(Tokenize(text)))).Should(Equal(maxCompletionTokens)) - Expect([]string{StopFinishReason, LengthFinishReason}).Should(ContainElement(finishReason)) + tokensCnt := int64(len(Tokenize(text))) + Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens)) + if tokensCnt == maxCompletionTokens { + Expect(finishReason).To(Equal(LengthFinishReason)) + } else { + Expect(tokensCnt).To(BeNumerically("<", maxCompletionTokens)) + Expect(finishReason).To(Equal(StopFinishReason)) + } }) It("should return long text", func() { // return required number of tokens although it is higher than ResponseLenMax maxCompletionTokens := int64(ResponseLenMax * 5) text, finishReason := GetRandomResponseText(&maxCompletionTokens) - Expect(int64(len(Tokenize(text)))).Should(Equal(maxCompletionTokens)) + tokensCnt := int64(len(Tokenize(text))) + Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens)) Expect(IsValidText(text)).To(BeTrue()) - Expect([]string{StopFinishReason, LengthFinishReason}).Should(ContainElement(finishReason)) + if tokensCnt == maxCompletionTokens { + Expect(finishReason).To(Equal(LengthFinishReason)) + } else { + Expect(tokensCnt).To(BeNumerically("<", maxCompletionTokens)) + Expect(finishReason).To(Equal(StopFinishReason)) + } }) })