Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 63 additions & 2 deletions pkg/common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ const (
RemoteDecodeFinishReason = "remote_decode"
)

// Histogram used to choose the length of a randomly generated response
// when the request limits the number of tokens.
var (
	// randomValuesBuckets holds the probability of every bucket of the
	// histogram. The last entry is the probability of a response of exactly
	// max tokens, the preceding entries are the probabilities of the
	// sub-ranges below max tokens, starting from the one containing a
	// single token.
	randomValuesBuckets = []float64{0.2, 0.3, 0.2, 0.05, 0.1, 0.15}

	// cumulativeBuckets holds the running sums of randomValuesBuckets,
	// it is filled by init.
	cumulativeBuckets []float64
)

// list of responses to use in random mode for completion requests
var chatCompletionFakeResponses = []string{
`Testing@, #testing 1$ ,2%,3^, [4&*5], 6~, 7-_ + (8 : 9) / \ < > .`,
Expand All @@ -54,6 +57,16 @@ var chatCompletionFakeResponses = []string{
`Give a man a fish and you feed him for a day; teach a man to fish and you feed him for a lifetime`,
}

// init fills cumulativeBuckets with the running totals of
// randomValuesBuckets: entry i is the sum of the weights 0..i.
func init() {
	cumulativeBuckets = make([]float64, 0, len(randomValuesBuckets))

	var total float64
	for _, weight := range randomValuesBuckets {
		total += weight
		cumulativeBuckets = append(cumulativeBuckets, total)
	}
}

// returns the max tokens or error if incorrect
func GetMaxTokens(maxCompletionTokens *int64, maxTokens *int64) (*int64, error) {
var typeToken string
Expand Down Expand Up @@ -154,14 +167,62 @@ func GetRandomResponseText(maxCompletionTokens *int64) (string, string) {
if maxCompletionTokens == nil {
numOfTokens = GetRandomResponseLen()
} else {
numOfTokens = int(*maxCompletionTokens)
finishReason = GetRandomFinishReason()
maxTokens := int(*maxCompletionTokens)
// max tokens is defined - generate real length of the response based on it
numOfTokens = getResponseLengthByHistogram(maxTokens)
if numOfTokens == maxTokens {
// if the response should be created with the maximum number of tokens - finish reason will be 'length'
finishReason = LengthFinishReason
}
}

text := GetRandomText(numOfTokens)
return text, finishReason
}

// getResponseLengthByHistogram picks the number of tokens of a response that
// is limited by maxTokens, following the histogram stored in
// cumulativeBuckets.
//
// The last entry of cumulativeBuckets stands for "exactly maxTokens"; with the
// current table it is selected with probability 15%. All other entries split
// the lengths 1..maxTokens-1 into consecutive sub-ranges of (almost) equal
// size; with the current table there are 5 of them, selected with
// probabilities 20%, 30%, 20%, 5% and 10%, starting from the sub-range that
// contains one token. The length inside the selected sub-range is chosen by
// RandomInt(start, end) (presumably inclusive on both ends - confirm against
// its definition).
//
// Special cases:
//   - maxTokens <= 1 is returned as is;
//   - if maxTokens does not exceed the number of buckets, the histogram is not
//     applied and RandomInt(1, maxTokens) is returned;
//   - if the table has no bucket below maxTokens, maxTokens is returned.
func getResponseLengthByHistogram(maxTokens int) int {
	if maxTokens <= 1 {
		return maxTokens
	}
	if maxTokens <= len(cumulativeBuckets) {
		return RandomInt(1, maxTokens)
	}

	// every bucket except the last one covers lengths below maxTokens;
	// the count is derived from the table so both always stay in sync
	nonMaxBuckets := len(cumulativeBuckets) - 1
	if nonMaxBuckets < 1 {
		// nothing but the max-tokens bucket is defined
		return maxTokens
	}

	r := RandomFloat(0, 1)

	// above the cumulative probability of all non-max buckets - return maxTokens
	if r > cumulativeBuckets[nonMaxBuckets-1] {
		return maxTokens
	}

	// determine which bucket to use: the first one whose cumulative probability covers r
	bucketIndex := 0
	for i, c := range cumulativeBuckets[:nonMaxBuckets] {
		if r <= c {
			bucketIndex = i
			break
		}
	}

	// compute bucket ranges using integer arithmetic only:
	// boundary(i) == nonMaxCount*i/nonMaxBuckets, written via quotient and
	// remainder so that the intermediate product cannot overflow
	nonMaxCount := maxTokens - 1
	perBucket := nonMaxCount / nonMaxBuckets
	remainder := nonMaxCount % nonMaxBuckets
	boundary := func(i int) int {
		return perBucket*i + remainder*i/nonMaxBuckets
	}

	// nonMaxCount > nonMaxBuckets here, hence start <= end, and
	// end <= nonMaxCount, so no additional clamping is required
	start := boundary(bucketIndex) + 1
	end := boundary(bucketIndex + 1)

	// pick within the bucket's range
	return RandomInt(start, end)
}

// GetResponseText returns response text, from a given text
// considering max completion tokens if it is not nil, and a finish reason (stop or length)
func GetResponseText(maxCompletionTokens *int64, text string) (string, string) {
Expand Down
20 changes: 16 additions & 4 deletions pkg/common/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,28 @@ var _ = Describe("Utils", Ordered, func() {
It("should return short text", func() {
maxCompletionTokens := int64(2)
text, finishReason := GetRandomResponseText(&maxCompletionTokens)
Expect(int64(len(Tokenize(text)))).Should(Equal(maxCompletionTokens))
Expect([]string{StopFinishReason, LengthFinishReason}).Should(ContainElement(finishReason))
tokensCnt := int64(len(Tokenize(text)))
Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens))
if tokensCnt == maxCompletionTokens {
Expect(finishReason).To(Equal(LengthFinishReason))
} else {
Expect(tokensCnt).To(BeNumerically("<", maxCompletionTokens))
Expect(finishReason).To(Equal(StopFinishReason))
}
})
It("should return long text", func() {
// return required number of tokens although it is higher than ResponseLenMax
maxCompletionTokens := int64(ResponseLenMax * 5)
text, finishReason := GetRandomResponseText(&maxCompletionTokens)
Expect(int64(len(Tokenize(text)))).Should(Equal(maxCompletionTokens))
tokensCnt := int64(len(Tokenize(text)))
Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens))
Expect(IsValidText(text)).To(BeTrue())
Expect([]string{StopFinishReason, LengthFinishReason}).Should(ContainElement(finishReason))
if tokensCnt == maxCompletionTokens {
Expect(finishReason).To(Equal(LengthFinishReason))
} else {
Expect(tokensCnt).To(BeNumerically("<", maxCompletionTokens))
Expect(finishReason).To(Equal(StopFinishReason))
}
})
})

Expand Down