Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 69 additions & 2 deletions pkg/common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ const (
RemoteDecodeFinishReason = "remote_decode"
)

// respLenBucketsProbabilities defines the probability of each bucket used when generating the number of tokens in a response.
// getResponseLengthByHistogram treats the last element as the probability of a response with exactly maxTokens tokens.
var respLenBucketsProbabilities = [...]float64{0.2, 0.3, 0.2, 0.05, 0.1, 0.15}
// cumulativeBucketsProbabilities holds the running sum of respLenBucketsProbabilities; it is filled in by init.
var cumulativeBucketsProbabilities []float64

// list of responses to use in random mode for completion requests
var chatCompletionFakeResponses = []string{
`Testing@, #testing 1$ ,2%,3^, [4&*5], 6~, 7-_ + (8 : 9) / \ < > .`,
Expand All @@ -54,6 +58,16 @@ var chatCompletionFakeResponses = []string{
`Give a man a fish and you feed him for a day; teach a man to fish and you feed him for a lifetime`,
}

// init precomputes cumulativeBucketsProbabilities: element i is the sum of
// respLenBucketsProbabilities[0..i], so the slice has the same length as the
// probabilities array.
func init() {
	running := 0.0
	cumulative := make([]float64, 0, len(respLenBucketsProbabilities))

	for _, probability := range respLenBucketsProbabilities {
		running += probability
		cumulative = append(cumulative, running)
	}

	cumulativeBucketsProbabilities = cumulative
}

// returns the max tokens or error if incorrect
func GetMaxTokens(maxCompletionTokens *int64, maxTokens *int64) (*int64, error) {
var typeToken string
Expand Down Expand Up @@ -154,14 +168,67 @@ func GetRandomResponseText(maxCompletionTokens *int64) (string, string) {
if maxCompletionTokens == nil {
numOfTokens = GetRandomResponseLen()
} else {
numOfTokens = int(*maxCompletionTokens)
finishReason = GetRandomFinishReason()
maxTokens := int(*maxCompletionTokens)
// max tokens is defined - generate real length of the response based on it
numOfTokens = getResponseLengthByHistogram(maxTokens)
if numOfTokens == maxTokens {
// if response should be create with maximum number of tokens - finish reason will be 'length'
finishReason = LengthFinishReason
}
}

text := GetRandomText(numOfTokens)
return text, finishReason
}

// getResponseLengthByHistogram picks the number of tokens for a response, given the
// maxTokens limit and the pre-computed cumulativeBucketsProbabilities.
//
// The histogram consists of equally sized buckets that together cover [1, maxTokens-1],
// plus a special last bucket that contains only the maxTokens value. The last element of
// respLenBucketsProbabilities is the probability of a response with exactly maxTokens
// tokens; the other elements are the probabilities of the equally sized buckets.
//
// Special cases:
//   - maxTokens <= 1: maxTokens is returned unchanged.
//   - maxTokens no larger than the number of buckets: the result is RandomInt(1, maxTokens).
func getResponseLengthByHistogram(maxTokens int) int {
	bucketCount := len(cumulativeBucketsProbabilities)

	switch {
	case maxTokens <= 1:
		return maxTokens
	case maxTokens <= bucketCount:
		// too few possible values to spread over the histogram - draw directly from [1, maxTokens]
		return RandomInt(1, maxTokens)
	}

	r := RandomFloat(0, 1)

	// anything above the cumulative probability of the second-to-last bucket
	// belongs to the special last bucket, i.e. the full maxTokens response
	if r > cumulativeBucketsProbabilities[bucketCount-2] {
		return maxTokens
	}

	// the chosen bucket is the first one whose cumulative probability is not smaller than r;
	// the index defaults to the last bucket, although a match is expected here because
	// r did not exceed the second-to-last cumulative probability in the check above
	chosen := bucketCount - 1
	for i := 0; i < bucketCount; i++ {
		if r <= cumulativeBucketsProbabilities[i] {
			chosen = i
			break
		}
	}

	// width of each regular bucket (the special last bucket is excluded)
	width := float64(maxTokens-1) / float64(bucketCount-1)
	// smallest value in the chosen bucket
	lo := int(width*float64(chosen)) + 1
	// largest value in the chosen bucket
	hi := int(width * float64(chosen+1))
	// rounding may push the upper bound up to maxTokens, which is reserved for the last bucket
	if hi >= maxTokens {
		hi = maxTokens - 1
	}

	// uniform choice inside the bucket's range
	return RandomInt(lo, hi)
}

// GetResponseText returns response text, from a given text
// considering max completion tokens if it is not nil, and a finish reason (stop or length)
func GetResponseText(maxCompletionTokens *int64, text string) (string, string) {
Expand Down
20 changes: 16 additions & 4 deletions pkg/common/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,28 @@ var _ = Describe("Utils", Ordered, func() {
// Small limit (2 tokens): checks the token count of the generated text against the limit
// and that the returned finish reason is consistent with that count.
It("should return short text", func() {
maxCompletionTokens := int64(2)
text, finishReason := GetRandomResponseText(&maxCompletionTokens)
// NOTE(review): the next two assertions look like the pre-change lines of this diff hunk
// (exact-length match, either finish reason accepted) - confirm they are absent from the merged file.
Expect(int64(len(Tokenize(text)))).Should(Equal(maxCompletionTokens))
Expect([]string{StopFinishReason, LengthFinishReason}).Should(ContainElement(finishReason))
tokensCnt := int64(len(Tokenize(text)))
// the response must never be longer than the limit
Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens))
if tokensCnt == maxCompletionTokens {
// a response that reaches the limit must report LengthFinishReason
Expect(finishReason).To(Equal(LengthFinishReason))
} else {
// a response shorter than the limit must report StopFinishReason
Expect(tokensCnt).To(BeNumerically("<", maxCompletionTokens))
Expect(finishReason).To(Equal(StopFinishReason))
}
})
// Large limit (5 * ResponseLenMax): checks the token count against the limit, that the text
// passes IsValidText, and that the finish reason is consistent with the token count.
It("should return long text", func() {
// the limit is deliberately higher than ResponseLenMax
maxCompletionTokens := int64(ResponseLenMax * 5)
text, finishReason := GetRandomResponseText(&maxCompletionTokens)
// NOTE(review): this exact-length assertion looks like a pre-change line of this diff hunk -
// confirm it is absent from the merged file.
Expect(int64(len(Tokenize(text)))).Should(Equal(maxCompletionTokens))
tokensCnt := int64(len(Tokenize(text)))
// the response must never be longer than the limit
Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens))
Expect(IsValidText(text)).To(BeTrue())
// NOTE(review): this either-finish-reason assertion also looks like a pre-change line - confirm.
Expect([]string{StopFinishReason, LengthFinishReason}).Should(ContainElement(finishReason))
if tokensCnt == maxCompletionTokens {
// a response that reaches the limit must report LengthFinishReason
Expect(finishReason).To(Equal(LengthFinishReason))
} else {
// a response shorter than the limit must report StopFinishReason
Expect(tokensCnt).To(BeNumerically("<", maxCompletionTokens))
Expect(finishReason).To(Equal(StopFinishReason))
}
})
})

Expand Down
Loading