Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 62 additions & 2 deletions pkg/common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ const (
RemoteDecodeFinishReason = "remote_decode"
)

// Histogram that drives the length of generated responses when the request
// limits the number of tokens.
var (
	// respLenBucketsProbabilities lists the probability of each response
	// length bucket; the values sum to 1.
	respLenBucketsProbabilities = [...]float64{0.2, 0.3, 0.2, 0.05, 0.1, 0.15}

	// cumulativeBucketsProbabilities is the running sum of
	// respLenBucketsProbabilities; it is nil until init fills it.
	cumulativeBucketsProbabilities []float64
)

// list of responses to use in random mode for completion requests
var chatCompletionFakeResponses = []string{
`Testing@, #testing 1$ ,2%,3^, [4&*5], 6~, 7-_ + (8 : 9) / \ < > .`,
Expand All @@ -54,6 +57,16 @@ var chatCompletionFakeResponses = []string{
`Give a man a fish and you feed him for a day; teach a man to fish and you feed him for a lifetime`,
}

// init fills cumulativeBucketsProbabilities with the running sum of
// respLenBucketsProbabilities, so that element i holds the total probability
// of buckets 0..i.
func init() {
	running := 0.0
	cumulative := make([]float64, 0, len(respLenBucketsProbabilities))

	for idx := range respLenBucketsProbabilities {
		running += respLenBucketsProbabilities[idx]
		cumulative = append(cumulative, running)
	}

	cumulativeBucketsProbabilities = cumulative
}

// returns the max tokens or error if incorrect
func GetMaxTokens(maxCompletionTokens *int64, maxTokens *int64) (*int64, error) {
var typeToken string
Expand Down Expand Up @@ -154,14 +167,61 @@ func GetRandomResponseText(maxCompletionTokens *int64) (string, string) {
if maxCompletionTokens == nil {
numOfTokens = GetRandomResponseLen()
} else {
numOfTokens = int(*maxCompletionTokens)
finishReason = GetRandomFinishReason()
maxTokens := int(*maxCompletionTokens)
// max tokens is defined - generate real length of the response based on it
numOfTokens = getResponseLengthByHistogram(maxTokens)
if numOfTokens == maxTokens {
// if the response is created with the maximum number of tokens - finish reason will be 'length'
finishReason = LengthFinishReason
}
}

text := GetRandomText(numOfTokens)
return text, finishReason
}

// getResponseLengthByHistogram chooses the number of tokens for a response that
// is limited to maxTokens, using the pre-computed cumulativeBucketsProbabilities.
//
// The last histogram entry is the probability of a response of exactly
// maxTokens tokens; every other entry is the probability of one of the equally
// sized buckets that split the range [1, maxTokens-1].
//
// Values of maxTokens that are <= 1 are returned unchanged. When maxTokens does
// not exceed the number of histogram entries the histogram is not used and the
// result comes from RandomInt(1, maxTokens).
func getResponseLengthByHistogram(maxTokens int) int {
	entries := len(cumulativeBucketsProbabilities)

	switch {
	case maxTokens <= 1:
		return maxTokens
	case maxTokens <= entries:
		// too few tokens to split into buckets
		return RandomInt(1, maxTokens)
	}

	sample := RandomFloat(0, 1)

	// Anything above the cumulative probability of the regular buckets falls
	// into the last histogram entry, i.e. a response of maxTokens tokens.
	if sample > cumulativeBucketsProbabilities[entries-2] {
		return maxTokens
	}

	// The first bucket whose cumulative probability covers the sample wins.
	chosen := entries - 1
	for idx := 0; idx < entries; idx++ {
		if sample <= cumulativeBucketsProbabilities[idx] {
			chosen = idx
			break
		}
	}

	// Width of a regular bucket; maxTokens itself is excluded from the range.
	width := float64(maxTokens-1) / float64(entries-1)

	low := int(width*float64(chosen)) + 1
	high := int(width * float64(chosen+1))
	if high >= maxTokens {
		high = maxTokens - 1
	}

	// Final length is drawn from the chosen bucket's [low, high] bounds.
	return RandomInt(low, high)
}

// GetResponseText returns response text, from a given text
// considering max completion tokens if it is not nil, and a finish reason (stop or length)
func GetResponseText(maxCompletionTokens *int64, text string) (string, string) {
Expand Down
20 changes: 16 additions & 4 deletions pkg/common/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,28 @@ var _ = Describe("Utils", Ordered, func() {
It("should return short text", func() {
maxCompletionTokens := int64(2)
text, finishReason := GetRandomResponseText(&maxCompletionTokens)
Expect(int64(len(Tokenize(text)))).Should(Equal(maxCompletionTokens))
Expect([]string{StopFinishReason, LengthFinishReason}).Should(ContainElement(finishReason))
tokensCnt := int64(len(Tokenize(text)))
Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens))
if tokensCnt == maxCompletionTokens {
Expect(finishReason).To(Equal(LengthFinishReason))
} else {
Expect(tokensCnt).To(BeNumerically("<", maxCompletionTokens))
Expect(finishReason).To(Equal(StopFinishReason))
}
})
It("should return long text", func() {
// return required number of tokens although it is higher than ResponseLenMax
maxCompletionTokens := int64(ResponseLenMax * 5)
text, finishReason := GetRandomResponseText(&maxCompletionTokens)
Expect(int64(len(Tokenize(text)))).Should(Equal(maxCompletionTokens))
tokensCnt := int64(len(Tokenize(text)))
Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens))
Expect(IsValidText(text)).To(BeTrue())
Expect([]string{StopFinishReason, LengthFinishReason}).Should(ContainElement(finishReason))
if tokensCnt == maxCompletionTokens {
Expect(finishReason).To(Equal(LengthFinishReason))
} else {
Expect(tokensCnt).To(BeNumerically("<", maxCompletionTokens))
Expect(finishReason).To(Equal(StopFinishReason))
}
})
})

Expand Down
Loading