diff --git a/.gitignore b/.gitignore index 950b0cb4..d684889b 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,7 @@ vendor .DS_Store *.test manifests/dev-config.yaml +pkg/dataset/.llm-d +pkg/llm-d-inference-sim/tests-tmp/ +pkg/llm-d-inference-sim/.llm-d/ +.llm-d/ diff --git a/Makefile b/Makefile index 40392e9a..a71b78c6 100644 --- a/Makefile +++ b/Makefile @@ -85,7 +85,7 @@ format: ## Format Go source files test: $(GINKGO) download-tokenizer download-zmq ## Run tests @printf "\033[33;1m==== Running tests ====\033[0m\n" ifdef GINKGO_FOCUS - CGO_ENABLED=1 $(GINKGO) -ldflags="$(GO_LDFLAGS)" -v -r --focus="$(GINKGO_FOCUS)" + CGO_ENABLED=1 ginkgo -ldflags="$(GO_LDFLAGS)" -v -r -- -ginkgo.v -ginkgo.focus="$(GINKGO_FOCUS)" else CGO_ENABLED=1 $(GINKGO) -ldflags="$(GO_LDFLAGS)" -v -r endif diff --git a/README.md b/README.md index 8e63b793..fa4dfde2 100644 --- a/README.md +++ b/README.md @@ -149,7 +149,20 @@ For more details see the .`, - `Testing, testing 1,2,3.`, - `I am fine, how are you today?`, - `I am your AI assistant, how can I help you today?`, - `Today is a nice sunny day.`, - `The temperature here is twenty-five degrees centigrade.`, - `Today it is partially cloudy and raining.`, - `To be or not to be that is the question.`, - `Alas, poor Yorick! I knew him, Horatio: A fellow of infinite jest`, - `The rest is silence. `, - `Give a man a fish and you feed him for a day; teach a man to fish and you feed him for a lifetime`, -} - -func init() { - cumulativeBucketsProbabilities = make([]float64, len(respLenBucketsProbabilities)) - sum := 0.0 - - for i, val := range respLenBucketsProbabilities { - sum += val - cumulativeBucketsProbabilities[i] = sum - } -} - -// returns the max tokens or error if incorrect -func GetMaxTokens(maxCompletionTokens *int64, maxTokens *int64) (*int64, error) { - var typeToken string - var tokens *int64 - // if both arguments are passed, - // use maxCompletionTokens - // as in the real vllm - if maxCompletionTokens != nil { - tokens = maxCompletionTokens - typeToken = "max_completion_tokens" - } else if maxTokens != nil { - tokens = maxTokens - typeToken = "max_tokens" - } - if tokens != nil && *tokens < 1 { - return nil, fmt.Errorf("%s must be at least 1, got %d", typeToken, *tokens) - } - return tokens, nil -} - // ValidateContextWindow checks if the request fits within the model's context window // Returns validation result, actual completion tokens, and total tokens func ValidateContextWindow(promptTokens int, maxCompletionTokens *int64, maxModelLen int) (bool, int64, int64) { @@ -107,200 +38,6 @@ func ValidateContextWindow(promptTokens int, maxCompletionTokens *int64, maxMode return isValid, completionTokens, totalTokens } -// GetRandomResponseLen returns int in range [1, responseLenMax] -// numbers are chosen according a gaussian distribution with mean responseLenMean, and standard deviation responseLenStddev -func GetRandomResponseLen() int { - for { - val := rand.NormFloat64()*responseLenStddev + responseLenMean - if val >= 1 && val <= ResponseLenMax { - return int(math.Round(val)) - } - // else reject and resample - } -} - -// GetRandomFinishReason returns finish reason with the probability for 'stop' as defined by stopFinishReasonProbability -func GetRandomFinishReason() string { - if rand.Float64() < stopFinishReasonProbability { - return StopFinishReason - } - return LengthFinishReason -} - -// GetRandomText generates random text for the required number of tokens, -// select randomly a sentence from chatCompletionFakeResponses, -// if number of tokens is lower than required - select another sentence, -// continue until the required number of tokens is achieved -func GetRandomText(numOfTokens int) string { - allTokens := make([]string, 0) - - for len(allTokens) < numOfTokens { - index := RandomInt(0, len(chatCompletionFakeResponses)-1) - // create tokens from text, splitting by spaces and special characters - tokens := Tokenize(chatCompletionFakeResponses[index]) - remaining := numOfTokens - len(allTokens) - - if len(tokens) > remaining { - // there is too many tokens, append only the relevant part - tokens = tokens[:remaining] - } - - if len(allTokens) > 0 { - // for not first sentences add space to the first token to separate between sentences without adding an additional token - tokens[0] = " " + tokens[0] - } - - allTokens = append(allTokens, tokens...) - } - - // return all tokens as text - return strings.Join(allTokens, "") -} - -// GetRandomResponseText generates text to be returned in a response, and the finish reason (stop or length) -// if maxCompletionTokens is defined -// - currently, the generated number of words in the text will be equal to it value -// - in future - need to find statistics about generated tokens distribution and return less tokens in part os requests -// - finish reason will be chosen randomly from the collection (stop, length) with 80% for stop and 20% for length -// if maxCompletionTokens is nil -// - the response text's length is randomly chosen from the range [1, responseLenMax] according additional parameters -// - finish reason is stop -// if ignore_eos is true - the response will be generated with exactly maxCompletionTokens tokens -// - request was validated so that when ignore_eos is true, maxCompletionTokens must be defined -func GetRandomResponseText(maxCompletionTokens *int64, ignore_eos bool) (string, string) { - numOfTokens := 0 - finishReason := StopFinishReason - - // no max completion tokens, return text with random length - if maxCompletionTokens == nil { - numOfTokens = GetRandomResponseLen() - } else { - maxTokens := int(*maxCompletionTokens) - if ignore_eos { - numOfTokens = maxTokens - finishReason = LengthFinishReason - } else { - // max tokens is defined - generate real length of the response based on it - numOfTokens = getResponseLengthByHistogram(maxTokens) - if numOfTokens == maxTokens { - // if response should be create with maximum number of tokens - finish reason will be 'length' - finishReason = LengthFinishReason - } - } - } - - text := GetRandomText(numOfTokens) - return text, finishReason -} - -// getResponseLengthByHistogram calculates the number of tokens to be returned in a response based on the max tokens value and the pre-defined buckets. -// The response length is distributed according to the probabilities, defined in respLenBucketsProbabilities. -// The histogram contains equally sized buckets and the last special bucket, which contains only the maxTokens value. -// The last element of respLenBucketsProbabilities defines the probability of a reposnse with maxToken tokens. -// Other values define probabilities for the equally sized buckets. -// If maxToken is small (smaller than number of buckets) - the response length is randomly selected from the range [1, maxTokens] -func getResponseLengthByHistogram(maxTokens int) int { - if maxTokens <= 1 { - return maxTokens - } - // maxTokens is small - no need to use the histogram of probabilities, just select a random value in the range [1, maxTokens] - if maxTokens <= len(cumulativeBucketsProbabilities) { - res := RandomInt(1, maxTokens) - return res - } - - r := RandomFloat(0, 1) - - // check if r is in the last bucket, then maxTokens should be returned - if r > cumulativeBucketsProbabilities[len(cumulativeBucketsProbabilities)-2] { - return maxTokens - } - - // determine which bucket to use, the bucket with a cumulative probability larger than r is the bucket to use - // initialize bucketIndex with the last bucket to handle the case (which should not happen) when the probabilities sum is less than 1 - bucketIndex := len(cumulativeBucketsProbabilities) - 1 - for i, c := range cumulativeBucketsProbabilities { - if r <= c { - bucketIndex = i - break - } - } - - // calculate the size of all of the buckets (except the special last bucket) - start, end := calcBucketBoundaries(maxTokens, bucketIndex) - - // pick uniformly within the bucket’s range - return RandomInt(start, end) -} - -// calcBucketBoundaries calculates boundaries of a bucket with the given index. -// Maximum size for equally sized buckets is defined by maxFixedBucketSize. -// [maxFixedBucketSize*(number-of-buckets-1)+1] is the value of maxTokens for which -// division to equally size buckets will give buckets with size maxFixedBucketSize. -// If maxTokens is [maxFixedBucketSize*(number-of-buckets-1)+1] or less, -// all buckets will be of equal size, except the last bucket, which contains only one value. -// If maxTokens is higher than [maxFixedBucketSize*(number-of-buckets-1)+1], -// and flexBucketIndex is valid (between 0 and number of buckets - 1) the buckets sizes will not be equal. -// In this case, all buckets except the one at flexBucketIndex index will have size 20 (and the last is with size 1), -// and the bucket at flexBucketIndex index will 'stretch' to cover the remaining range. -func calcBucketBoundaries(maxTokens int, bucketIndex int) (start int, end int) { - maxEquallyBucketsSz := maxFixedBucketSize*(len(cumulativeBucketsProbabilities)-1) + 1 - - if maxTokens <= maxEquallyBucketsSz || flexBucketIndex < 0 || flexBucketIndex >= len(cumulativeBucketsProbabilities)-1 { - // create equally size buckets - // calculate the size of all of the buckets (except the special last bucket) - bucketSize := float64(maxTokens-1) / float64(len(cumulativeBucketsProbabilities)-1) - start = int(bucketSize*float64(bucketIndex)) + 1 - end = int(bucketSize * float64(bucketIndex+1)) - } else { - // create non-equally sized buckets and find boundaries of the required bucket - if bucketIndex < flexBucketIndex { - // the relevant bucket is before the flex bucket, all buckets are of the same size (maxFixedBucketSize) - // start is the minimum number in the required bucket - start = maxFixedBucketSize*bucketIndex + 1 - end = maxFixedBucketSize * (bucketIndex + 1) - } else { - flexBucketSize := maxTokens - (maxFixedBucketSize * (len(cumulativeBucketsProbabilities) - 2)) - - if bucketIndex == flexBucketIndex { - // the relevant bucket is the flex bucket - start = int(maxFixedBucketSize*float64(bucketIndex)) + 1 - end = maxFixedBucketSize*bucketIndex + flexBucketSize - } else { - // the relevant bucket is one of buckets after the flex bucket - start = int(maxFixedBucketSize*float64(bucketIndex-1)) + flexBucketSize + 1 - end = maxFixedBucketSize*bucketIndex + flexBucketSize - } - } - } - - // sometimes end could be maxTokens because of rounding, change the value to maxToken-1 - if end >= maxTokens { - end = maxTokens - 1 - } - - return start, end -} - -// GetResponseText returns response text, from a given text -// considering max completion tokens if it is not nil, and a finish reason (stop or length) -func GetResponseText(maxCompletionTokens *int64, text string) (string, string) { - // no max completion tokens, return entire text - if maxCompletionTokens == nil { - return text, StopFinishReason - } - - // create tokens from text, splitting by spaces - tokens := Tokenize(text) - - // return entire text - if *maxCompletionTokens >= int64(len(tokens)) { - return text, StopFinishReason - } - // return truncated text - return strings.Join(tokens[0:*maxCompletionTokens], " "), LengthFinishReason -} - func RandomNumericString(length int) string { digits := "0123456789" result := make([]byte, length) diff --git a/pkg/common/utils_test.go b/pkg/common/utils_test.go index d847df35..9a0af043 100644 --- a/pkg/common/utils_test.go +++ b/pkg/common/utils_test.go @@ -17,7 +17,6 @@ limitations under the License. package common import ( - "fmt" "time" . "github.com/onsi/ginkgo/v2" @@ -29,79 +28,6 @@ var _ = Describe("Utils", Ordered, func() { InitRandom(time.Now().UnixNano()) }) - Context("GetRandomResponseText", func() { - It("should return complete text", func() { - text, finishReason := GetRandomResponseText(nil, false) - Expect(IsValidText(text)).To(BeTrue()) - Expect(finishReason).Should(Equal(StopFinishReason)) - }) - It("should return short text", func() { - maxCompletionTokens := int64(2) - text, finishReason := GetRandomResponseText(&maxCompletionTokens, false) - tokensCnt := int64(len(Tokenize(text))) - Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens)) - if tokensCnt == maxCompletionTokens { - Expect(finishReason).To(Equal(LengthFinishReason)) - } else { - Expect(tokensCnt).To(BeNumerically("<", maxCompletionTokens)) - Expect(finishReason).To(Equal(StopFinishReason)) - } - }) - It("should return long text", func() { - // return required number of tokens although it is higher than ResponseLenMax - maxCompletionTokens := int64(ResponseLenMax * 5) - text, finishReason := GetRandomResponseText(&maxCompletionTokens, false) - tokensCnt := int64(len(Tokenize(text))) - Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens)) - Expect(IsValidText(text)).To(BeTrue()) - if tokensCnt == maxCompletionTokens { - Expect(finishReason).To(Equal(LengthFinishReason)) - } else { - Expect(tokensCnt).To(BeNumerically("<", maxCompletionTokens)) - Expect(finishReason).To(Equal(StopFinishReason)) - } - }) - - DescribeTable("should return exact num of tokens", - func(maxCompletionTokens int) { - n := int64(maxCompletionTokens) - text, finishReason := GetRandomResponseText(&n, true) - nGenTokens := int64(len(Tokenize(text))) - Expect(nGenTokens).Should(Equal(n)) - Expect(finishReason).To(Equal(LengthFinishReason)) - }, - func(maxCompletionTokens int) string { - return fmt.Sprintf("maxCompletionTokens: %d", maxCompletionTokens) - }, - Entry("1", 1), - Entry("42", 42), - Entry("99", 99), - Entry("10000", 10000), - ) - }) - - Context("GetResponseText", func() { - theText := "Give a man a fish and you feed him for a day; teach a man to fish and you feed him for a lifetime" - - It("should return the same text since max tokens is not defined", func() { - text, finishReason := GetResponseText(nil, theText) - Expect(text).Should(Equal(theText)) - Expect(finishReason).Should(Equal(StopFinishReason)) - }) - It("should return the same text since max tokens is higher than the text length", func() { - maxCompletionTokens := int64(1000) - text, finishReason := GetResponseText(&maxCompletionTokens, theText) - Expect(text).Should(Equal(theText)) - Expect(finishReason).Should(Equal(StopFinishReason)) - }) - It("should return partial text", func() { - maxCompletionTokens := int64(2) - text, finishReason := GetResponseText(&maxCompletionTokens, theText) - Expect(int64(len(Tokenize(text)))).Should(Equal(maxCompletionTokens)) - Expect(finishReason).Should(Equal(LengthFinishReason)) - }) - }) - Context("validateContextWindow", func() { It("should pass when total tokens are within limit", func() { promptTokens := 100 @@ -146,69 +72,4 @@ var _ = Describe("Utils", Ordered, func() { }) }) - Context("GetRandomText", func() { - lenArr := []int{5, 20, 50, 150} - - for _, len := range lenArr { - name := fmt.Sprintf("should return text with %d tokens", len) - It(name, func() { - text := GetRandomText(len) - Expect(Tokenize(text)).Should(HaveLen(len)) - }) - } - }) - - Context("IsValidText", func() { - validTxts := make([]string, 0) - invalidTxts := make([]string, 0) - - validTxts = append(validTxts, chatCompletionFakeResponses[0][:4]) - validTxts = append(validTxts, chatCompletionFakeResponses[1]) - validTxts = append(validTxts, chatCompletionFakeResponses[1]+" "+chatCompletionFakeResponses[2]) - - invalidTxts = append(invalidTxts, (chatCompletionFakeResponses[1] + " " + chatCompletionFakeResponses[2])[3:4]) - invalidTxts = append(invalidTxts, chatCompletionFakeResponses[0][4:]) - invalidTxts = append(invalidTxts, chatCompletionFakeResponses[1]+"-"+chatCompletionFakeResponses[2]) - invalidTxts = append(invalidTxts, chatCompletionFakeResponses[1]+" ") - invalidTxts = append(invalidTxts, chatCompletionFakeResponses[1]+" "+chatCompletionFakeResponses[2]) - - for _, txt := range validTxts { - It("text should be valid", func() { - Expect(IsValidText(txt)).To(BeTrue()) - }) - } - - for _, txt := range invalidTxts { - It("text should be invalid", func() { - Expect(IsValidText(txt)).To(BeFalse()) - }) - } - }) - - Context("validateBucketsBoundaries", func() { - type bucketBoundaries struct { - start int - end int - } - type bucketTest struct { - maxTokens int - expectedBuckets []bucketBoundaries - } - - tests := []bucketTest{{500, []bucketBoundaries{{1, 20}, {21, 40}, {41, 60}, {61, 480}, {481, 499}}}, - {47, []bucketBoundaries{{1, 9}, {10, 18}, {19, 27}, {28, 36}, {37, 46}}}, - {50, []bucketBoundaries{{1, 9}, {10, 19}, {20, 29}, {30, 39}, {40, 49}}}} - - for _, test := range tests { - Expect(test.expectedBuckets).To(HaveLen(len(cumulativeBucketsProbabilities) - 1)) - - It(fmt.Sprintf("should return bucket boundaries for maxTokens %d", test.maxTokens), func() { - for i := range len(cumulativeBucketsProbabilities) - 1 { - start, end := calcBucketBoundaries(test.maxTokens, i) - Expect(start).To(Equal(test.expectedBuckets[i].start)) - Expect(end).To(Equal(test.expectedBuckets[i].end)) - } - }) - } - }) }) diff --git a/pkg/dataset/.llm-d/test.invalid.column.sqlite3 b/pkg/dataset/.llm-d/test.invalid.column.sqlite3 new file mode 100644 index 00000000..b35ad60d Binary files /dev/null and b/pkg/dataset/.llm-d/test.invalid.column.sqlite3 differ diff --git a/pkg/dataset/.llm-d/test.invalid.sqlite3 b/pkg/dataset/.llm-d/test.invalid.sqlite3 new file mode 100644 index 00000000..cd087558 --- /dev/null +++ b/pkg/dataset/.llm-d/test.invalid.sqlite3 @@ -0,0 +1 @@ +Hello world! diff --git a/pkg/dataset/.llm-d/test.invalid.table.sqlite3 b/pkg/dataset/.llm-d/test.invalid.table.sqlite3 new file mode 100644 index 00000000..b059e36b Binary files /dev/null and b/pkg/dataset/.llm-d/test.invalid.table.sqlite3 differ diff --git a/pkg/dataset/.llm-d/test.invalid.type.sqlite3 b/pkg/dataset/.llm-d/test.invalid.type.sqlite3 new file mode 100644 index 00000000..dc84914e Binary files /dev/null and b/pkg/dataset/.llm-d/test.invalid.type.sqlite3 differ diff --git a/pkg/dataset/.llm-d/test.valid.sqlite3 b/pkg/dataset/.llm-d/test.valid.sqlite3 new file mode 100644 index 00000000..f6e6601e Binary files /dev/null and b/pkg/dataset/.llm-d/test.valid.sqlite3 differ diff --git a/pkg/dataset/custom_dataset.go b/pkg/dataset/custom_dataset.go new file mode 100644 index 00000000..ecf7659a --- /dev/null +++ b/pkg/dataset/custom_dataset.go @@ -0,0 +1,505 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package dataset + +import ( + "context" + "crypto/sha256" + "database/sql" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strconv" + "time" + + "github.com/go-logr/logr" + "github.com/llm-d/llm-d-inference-sim/pkg/common" + openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" +) + +type CustomDataset struct { + BaseDataset + db *sql.DB + hasWarned bool +} + +// use constants for expected column names and types +const ( + tableName = "llmd" + idCol = "id" + promptHashCol = "prompt_hash" + genTokensCol = "gen_tokens" + nGenTokensCol = "n_gen_tokens" + idColType = "INTEGER" + promptHashColType = "BLOB" + genTokensColType = "JSON" + nGenTokensColType = "INTEGER" + progressLogTimeInterval = 5 * time.Second + progressLogPercentInterval = 10 +) + +func (d *CustomDataset) downloadDataset(ctx context.Context, url string, path string) error { + folder := filepath.Dir(path) + err := os.MkdirAll(folder, 0755) + if err != nil { + return fmt.Errorf("failed to create parent directory: %w", err) + } + + if _, err := os.Stat(path); err == nil { + // file already exists + return errors.New("Dataset file already exists, should not download: " + path) + } + + out, err := os.Create(path) + if err != nil { + return err + } + defer func() { + cerr := out.Close() + if cerr != nil { + d.logger.Error(cerr, "failed to close file after download") + } + }() + + d.logger.Info("Using dataset-url", "dataset-url", url) + resp, err := http.Get(url) + if err != nil { + return err + } + defer func() { + cerr := resp.Body.Close() + if cerr != nil { + d.logger.Error(cerr, "failed to close response body after download") + } + }() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("bad status: %s", resp.Status) + } + + // Progress reader with context + pr := &progressReader{ + Reader: resp.Body, + total: resp.ContentLength, + logger: d.logger, + ctx: ctx, + startTime: time.Now(), + } + + written, err := io.Copy(out, pr) + if err != nil { + // Remove incomplete file + cerr := os.Remove(path) + if cerr != nil { + d.logger.Error(cerr, "failed to remove incomplete file after download") + } + // If context was cancelled, return a specific error + if errors.Is(err, context.Canceled) { + return errors.New("download cancelled by user") + } + return fmt.Errorf("failed to download file: %w", err) + } + // Check if file size is zero + if written == 0 { + cerr := os.Remove(path) + if cerr != nil { + d.logger.Error(cerr, "failed to remove empty file after download") + } + return errors.New("downloaded file is empty") + } + + // Ensure file is fully flushed and closed before returning success + if err := out.Sync(); err != nil { + cerr := os.Remove(path) + if cerr != nil { + d.logger.Error(cerr, "failed to remove incomplete file after download") + } + return fmt.Errorf("failed to sync file: %w", err) + } + + return nil +} + +// progressReader wraps an io.Reader and logs download progress. +type progressReader struct { + io.Reader + total int64 + downloaded int64 + startTime time.Time + lastPct int + lastLogTime time.Time + logger logr.Logger + ctx context.Context +} + +func (pr *progressReader) Read(p []byte) (int, error) { + select { + case <-pr.ctx.Done(): + return 0, pr.ctx.Err() + default: + } + n, err := pr.Reader.Read(p) + pr.downloaded += int64(n) + if pr.total > 0 { + pct := int(float64(pr.downloaded) * 100 / float64(pr.total)) + now := time.Now() + + timeSinceLastLog := now.Sub(pr.lastLogTime).Seconds() + pctDiff := pct - pr.lastPct + + if timeSinceLastLog >= progressLogTimeInterval.Seconds() || (pctDiff >= progressLogPercentInterval && pct != pr.lastPct) { + // progress will be shown every interval seconds or every interval percent of progress + pr.logProgress(pct) + pr.lastPct = pct + pr.lastLogTime = now + } + } + return n, err +} + +func (pr *progressReader) logProgress(pct int) { + elapsedTime := time.Since(pr.startTime).Seconds() + speed := float64(pr.downloaded) / (1024 * 1024 * elapsedTime) + remainingTime := float64(pr.total-pr.downloaded) / (float64(pr.downloaded) / elapsedTime) + if pct != 100 { + pr.logger.Info(fmt.Sprintf("Download progress: %d%%, Speed: %.2f MB/s, Remaining time: %.2fs", pct, speed, remainingTime)) + } else { + pr.logger.Info(fmt.Sprintf("Download completed: 100%%, Average Speed: %.2f MB/s, Total time: %.2fs", speed, elapsedTime)) + } +} + +func (d *CustomDataset) verifyDB() error { + rows, err := d.db.Query("PRAGMA table_info(" + tableName + ");") + if err != nil { + return fmt.Errorf("failed to query table info for `%s`: %w", tableName, err) + } + defer func() { + if cerr := rows.Close(); cerr != nil { + d.logger.Error(cerr, "failed to close rows after querying table info") + } + }() + + expectedColumns := map[string]string{ + idCol: idColType, + promptHashCol: promptHashColType, + genTokensCol: genTokensColType, + nGenTokensCol: nGenTokensColType, + } + + columnsFound := make(map[string]bool) + + var ( + columnName string + columnType string + cid int + notnull int + dfltValue interface{} + pk int + ) + + for rows.Next() { + err := rows.Scan(&cid, &columnName, &columnType, ¬null, &dfltValue, &pk) + if err != nil { + return fmt.Errorf("failed to scan table info row: %w", err) + } + if expectedType, exists := expectedColumns[columnName]; exists { + if columnType != expectedType { + return fmt.Errorf("column %s has incorrect type: expected %s, got %s", columnName, expectedType, columnType) + } + columnsFound[columnName] = true + } + } + + for col := range expectedColumns { + if !columnsFound[col] { + return fmt.Errorf("missing expected column in %s table: %s", tableName, col) + } + } + + return nil +} + +func (d *CustomDataset) getRecordsCount() (int, error) { + var count int + err := d.db.QueryRow("SELECT COUNT(" + promptHashCol + ") FROM " + tableName + ";").Scan(&count) + if err != nil { + return 0, fmt.Errorf("failed to query database: %w", err) + } + return count, nil +} + +func (d *CustomDataset) loadDatabaseInMemory(path string) error { + d.logger.Info("Loading database into memory...") + start := time.Now() + + // Create in-memory database + var err error + d.db, err = sql.Open("sqlite3", ":memory:") + if err != nil { + return fmt.Errorf("failed to create in-memory database: %w", err) + } + + // Use ATTACH to copy the database + attachSQL := fmt.Sprintf("ATTACH DATABASE '%s' AS source", path) + _, err = d.db.Exec(attachSQL) + if err != nil { + if closeErr := d.db.Close(); closeErr != nil { + d.logger.Error(closeErr, "failed to close in-memory database after attach failure") + } + d.db = nil + return fmt.Errorf("failed to attach source database: %w", err) + } + + // Copy the table structure first + _, err = d.db.Exec(`CREATE TABLE llmd ( + id INTEGER PRIMARY KEY, + prompt_hash BLOB, + gen_tokens JSON, + n_gen_tokens INTEGER + )`) + if err != nil { + if closeErr := d.db.Close(); closeErr != nil { + d.logger.Error(closeErr, "failed to close in-memory database after create table failure") + } + d.db = nil + return fmt.Errorf("failed to create table: %w", err) + } + + // Copy the data + _, err = d.db.Exec("INSERT INTO llmd SELECT * FROM source.llmd") + if err != nil { + if closeErr := d.db.Close(); closeErr != nil { + d.logger.Error(closeErr, "failed to close in-memory database after copy failure") + } + d.db = nil + return fmt.Errorf("failed to copy data: %w", err) + } + + // Detach the source database + _, err = d.db.Exec("DETACH DATABASE source") + if err != nil { + d.logger.Error(err, "failed to detach source database") + } + + loadTime := time.Since(start) + d.logger.Info("Database loaded into memory", "load_time", loadTime.String()) + return nil +} + +func (d *CustomDataset) connectToDB(path string, useInMemory bool) error { + if d.db != nil { + err := d.db.Close() + if err != nil { + d.logger.Error(err, "failed to close existing database connection") + } + d.db = nil + } + // check if file exists + _, err := os.Stat(path) + if err != nil { + return fmt.Errorf("database file does not exist: %w", err) + } + + if useInMemory { + err = d.loadDatabaseInMemory(path) + if err != nil { + return err + } + } else { + // Use file-based database (original behavior) + d.db, err = sql.Open("sqlite3", path) + if err != nil { + return fmt.Errorf("failed to open database: %w", err) + } + + // Check if there are other connections to the database + _, err = d.db.Exec("BEGIN EXCLUSIVE;") + if err != nil { + if closeErr := d.db.Close(); closeErr != nil { + d.logger.Error(closeErr, "failed to close database after failing to acquire exclusive lock") + } + d.db = nil + return fmt.Errorf("database is locked or has other active connections: %w", err) + } + } + + err = d.verifyDB() + if err != nil { + return fmt.Errorf("failed to verify database: %w", err) + } + + count, err := d.getRecordsCount() + if err != nil { + d.logger.Error(err, "failed to get records count") + return fmt.Errorf("failed to query database: %w", err) + } + + if useInMemory { + d.logger.Info("In-memory database connected successfully", "path", path, "records count", count) + } else { + d.logger.Info("Database connected successfully", "path", path, "records count", count) + } + return nil +} + +func (d *CustomDataset) Init(ctx context.Context, logger logr.Logger, path string, url string, useInMemory bool) error { + d.logger = logger + if path == "" { + return errors.New("no dataset path provided") + } + d.hasWarned = false + if url == "" { + d.logger.Info("Using dataset from", "path", path) + return d.connectToDB(path, useInMemory) + } + _, err := os.Stat(path) + if err != nil { + // file does not exist, download it + err = d.downloadDataset(ctx, url, path) + if err != nil { + // if the file is created but incomplete, remove it + if _, statErr := os.Stat(path); statErr == nil { + cerr := os.Remove(path) + if cerr != nil { + d.logger.Error(cerr, "failed to remove incomplete file after download") + } + } + return fmt.Errorf("failed to download dataset: %w", err) + } + } + d.logger.Info("Using dataset path", "dataset-path", path) + + return d.connectToDB(path, useInMemory) +} + +func (d *CustomDataset) Close() error { + // Release db lock (only for file-based databases) + _, err := d.db.Exec("ROLLBACK;") + if err != nil { + if cerr := d.db.Close(); cerr != nil { + d.logger.Error(cerr, "failed to close database after failing to acquire exclusive lock") + } + d.db = nil + return fmt.Errorf("failed to release exclusive lock: %w", err) + } + + if d.db != nil { + return d.db.Close() + } + return nil +} + +func unmarshalAllRecords(rows *sql.Rows) ([][]string, error) { + var tokensList [][]string + for rows.Next() { + var tokensJSON string + if err := rows.Scan(&tokensJSON); err != nil { + return nil, fmt.Errorf("failed to scan row: %w", err) + } + + var tokens []string + if err := json.Unmarshal([]byte(tokensJSON), &tokens); err != nil { + return nil, fmt.Errorf("failed to unmarshal tokens JSON: %w", err) + } + tokensList = append(tokensList, tokens) + } + return tokensList, nil +} + +func (d *CustomDataset) GetPromptHash(req openaiserverapi.CompletionRequest) []byte { + hashArray := sha256.Sum256([]byte(req.GetFullPrompt())) + return hashArray[:] +} + +func (d *CustomDataset) GetPromptHashHex(hashBytes []byte) string { + return hex.EncodeToString(hashBytes) +} + +// GetTokens returns tokens and finishReason for the given request and mode (echo or random) +func (d *CustomDataset) GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error) { + if mode == common.ModeEcho { + return d.echo(req) + } + nTokensToGen, finishReason := howManyTokensToGen(d.extractMaxTokens(req), req.GetIgnoreEOS()) + tokens, err := d.GenerateTokens(req, nTokensToGen, finishReason) + return tokens, finishReason, err +} + +func (d *CustomDataset) query(query string, nTokens int) ([][]string, error) { + rows, err := d.db.Query(query) + if err != nil { + if !d.hasWarned { + d.logger.Error(err, "Failed to query database. Ensure dataset file is still valid. Will generate random tokens instead.") + d.hasWarned = true + } + return [][]string{GenPresetRandomTokens(nTokens)}, nil + } + defer func() { + if cerr := rows.Close(); cerr != nil { + d.logger.Error(cerr, "failed to close rows after query") + } + }() + return unmarshalAllRecords(rows) +} + +func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int, finishReason string) ([]string, error) { + // query by prompt hash first + promptHash := d.GetPromptHash(req) + promptHashHex := d.GetPromptHashHex(promptHash) + query := "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + promptHashCol + "=X'" + promptHashHex + "';" + tokensList, err := d.query(query, nTokens) + + // filter out results according to finish reason + var filteredTokensList [][]string + if finishReason != LengthFinishReason && finishReason != StopFinishReason { + d.logger.Error(errors.New("unknown finish reason"), "Unexpected finish reason", "reason", finishReason) + } + for _, tokens := range tokensList { + if finishReason == StopFinishReason && len(tokens) <= nTokens { + filteredTokensList = append(filteredTokensList, tokens) + } else if finishReason == LengthFinishReason && len(tokens) == nTokens { + filteredTokensList = append(filteredTokensList, tokens) + } + } + tokensList = filteredTokensList + + if err != nil || len(filteredTokensList) == 0 { + switch finishReason { + case LengthFinishReason: + query = "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + nGenTokensCol + "=" + strconv.Itoa(nTokens) + ";" + tokensList, err = d.query(query, nTokens) + case StopFinishReason: + query = "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + nGenTokensCol + "<=" + strconv.Itoa(nTokens) + ";" + tokensList, err = d.query(query, nTokens) + } + } + + if err != nil || len(tokensList) == 0 { + // if both queries fail or return no results, generate random tokens + return GenPresetRandomTokens(nTokens), nil + } + if d.hasWarned { + d.hasWarned = false + } + randIndex := common.RandomInt(0, len(tokensList)-1) + return tokensList[randIndex], nil +} diff --git a/pkg/dataset/custom_dataset_test.go b/pkg/dataset/custom_dataset_test.go new file mode 100644 index 00000000..afd734a2 --- /dev/null +++ b/pkg/dataset/custom_dataset_test.go @@ -0,0 +1,216 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package dataset + +import ( + "context" + "encoding/json" + "os" + "time" + + "github.com/llm-d/llm-d-inference-sim/pkg/common" + openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "k8s.io/klog/v2" + + _ "github.com/mattn/go-sqlite3" +) + +const ( + testPrompt = "Hello world!" +) + +var _ = Describe("CustomDataset", Ordered, func() { + var ( + dataset *CustomDataset + file_folder string + path string + validDBPath string + pathToInvalidDB string + pathNotExist string + pathToInvalidTableDB string + pathToInvalidColumnDB string + pathToInvalidTypeDB string + ) + + BeforeAll(func() { + common.InitRandom(time.Now().UnixNano()) + }) + + BeforeEach(func() { + dataset = &CustomDataset{} + file_folder = ".llm-d" + path = file_folder + "/test.sqlite3" + err := os.MkdirAll(file_folder, os.ModePerm) + Expect(err).NotTo(HaveOccurred()) + validDBPath = file_folder + "/test.valid.sqlite3" + pathNotExist = file_folder + "/test.notexist.sqlite3" + pathToInvalidDB = file_folder + "/test.invalid.sqlite3" + pathToInvalidTableDB = file_folder + "/test.invalid.table.sqlite3" + pathToInvalidColumnDB = file_folder + "/test.invalid.column.sqlite3" + pathToInvalidTypeDB = file_folder + "/test.invalid.type.sqlite3" + }) + + AfterEach(func() { + if dataset.db != nil { + err := dataset.db.Close() + Expect(err).NotTo(HaveOccurred()) + } + }) + + It("should return error for invalid DB path", func() { + err := dataset.connectToDB("/invalid/path/to/db.sqlite", false) + Expect(err).To(HaveOccurred()) + }) + + It("should download file from url", func() { + // remove file if it exists + _, err := os.Stat(path) + if err == nil { + err = os.Remove(path) + Expect(err).NotTo(HaveOccurred()) + } + + url := "https://llm-d.ai" + err = dataset.downloadDataset(context.Background(), url, path) + Expect(err).NotTo(HaveOccurred()) + _, err = os.Stat(path) + Expect(err).NotTo(HaveOccurred()) + err = os.Remove(path) + Expect(err).NotTo(HaveOccurred()) + }) + + It("should not download file from url", func() { + url := "https://256.256.256.256" // invalid url + err := dataset.downloadDataset(context.Background(), url, path) + Expect(err).To(HaveOccurred()) + }) + + It("should successfully init dataset", func() { + err := dataset.Init(context.Background(), klog.Background(), validDBPath, "", false) + Expect(err).NotTo(HaveOccurred()) + + row := dataset.db.QueryRow("SELECT n_gen_tokens FROM llmd WHERE prompt_hash=X'74bf14c09c038321cba39717dae1dc732823ae4abd8e155959367629a3c109a8';") + var n_gen_tokens int + err = row.Scan(&n_gen_tokens) + Expect(err).NotTo(HaveOccurred()) + Expect(n_gen_tokens).To(Equal(4)) + + var jsonStr string + row = dataset.db.QueryRow("SELECT gen_tokens FROM llmd WHERE prompt_hash=X'74bf14c09c038321cba39717dae1dc732823ae4abd8e155959367629a3c109a8';") + err = row.Scan(&jsonStr) + Expect(err).NotTo(HaveOccurred()) + var tokens []string + err = json.Unmarshal([]byte(jsonStr), &tokens) + Expect(err).NotTo(HaveOccurred()) + Expect(tokens).To(Equal([]string{"Hello", " llm-d ", "world", "!"})) + + }) + + It("should return error for non-existing DB path", func() { + err := dataset.connectToDB(pathNotExist, false) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("database file does not exist")) + }) + + It("should return error for invalid DB file", func() { + err := dataset.connectToDB(pathToInvalidDB, false) + Expect(err).To(HaveOccurred()) + }) + + It("should return error for DB with invalid table", func() { + err := dataset.connectToDB(pathToInvalidTableDB, false) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("failed to verify database")) + }) + + It("should return error for DB with invalid column", func() { + err := dataset.connectToDB(pathToInvalidColumnDB, false) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("missing expected column")) + }) + + It("should return error for DB with invalid column type", func() { + err := dataset.connectToDB(pathToInvalidTypeDB, false) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("incorrect type")) + }) + + It("should return correct prompt hash in bytes", func() { + // b't\xbf\x14\xc0\x9c\x03\x83!\xcb\xa3\x97\x17\xda\xe1\xdcs(#\xaeJ\xbd\x8e\x15YY6v)\xa3\xc1\t\xa8' + expectedHashBytes := []byte{0x74, 0xbf, 0x14, 0xc0, 0x9c, 0x03, 0x83, 0x21, 0xcb, 0xa3, 0x97, 0x17, 0xda, 0xe1, 0xdc, 0x73, 0x28, 0x23, 0xae, 0x4a, 0xbd, 0x8e, 0x15, 0x59, 0x59, 0x36, 0x76, 0x29, 0xa3, 0xc1, 0x09, 0xa8} + + req := &openaiserverapi.TextCompletionRequest{ + Prompt: testPrompt, + } + + hashBytes := dataset.GetPromptHash(req) + Expect(hashBytes).To(Equal(expectedHashBytes)) + }) + + It("should return correct prompt hash in hex", func() { + expectedHashHex := "74bf14c09c038321cba39717dae1dc732823ae4abd8e155959367629a3c109a8" + + req := &openaiserverapi.TextCompletionRequest{ + Prompt: testPrompt, + } + + hashBytes := dataset.GetPromptHash(req) + hashHex := dataset.GetPromptHashHex(hashBytes) + Expect(hashHex).To(Equal(expectedHashHex)) + }) + + It("should return tokens for existing prompt", func() { + err := dataset.Init(context.Background(), klog.Background(), validDBPath, "", false) + Expect(err).NotTo(HaveOccurred()) + + req := &openaiserverapi.TextCompletionRequest{ + Prompt: testPrompt, + } + tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom) + Expect(err).NotTo(HaveOccurred()) + Expect(finishReason).To(Equal(StopFinishReason)) + Expect(tokens).To(Equal([]string{"Hello", " llm-d ", "world", "!"})) + }) + + It("should return at most 2 tokens for existing prompt", func() { + err := dataset.Init(context.Background(), klog.Background(), validDBPath, "", false) + Expect(err).NotTo(HaveOccurred()) + n := int64(2) + req := &openaiserverapi.TextCompletionRequest{ + Prompt: testPrompt, + MaxTokens: &n, + } + tokens, _, err := dataset.GetTokens(req, common.ModeRandom) + Expect(err).NotTo(HaveOccurred()) + Expect(len(tokens)).To(BeNumerically("<=", 2)) + }) + + It("should successfully init dataset with in-memory option", func() { + err := dataset.Init(context.Background(), klog.Background(), validDBPath, "", true) + Expect(err).NotTo(HaveOccurred()) + + req := &openaiserverapi.TextCompletionRequest{ + Prompt: testPrompt, + } + tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom) + Expect(err).NotTo(HaveOccurred()) + Expect(finishReason).To(Equal(StopFinishReason)) + Expect(tokens).To(Equal([]string{"Hello", " llm-d ", "world", "!"})) + }) +}) diff --git a/pkg/dataset/dataset.go b/pkg/dataset/dataset.go new file mode 100644 index 00000000..41cb8354 --- /dev/null +++ b/pkg/dataset/dataset.go @@ -0,0 +1,334 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package dataset + +import ( + "context" + "errors" + "math" + "math/rand" + + "github.com/go-logr/logr" + "github.com/llm-d/llm-d-inference-sim/pkg/common" + openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" + _ "github.com/mattn/go-sqlite3" +) + +const ( + RoleAssistant = "assistant" + RoleUser = "user" +) + +const ( + ResponseLenMax = 128 + responseLenMean = 40 + responseLenStddev = 20 + stopFinishReasonProbability = 0.8 + + StopFinishReason = "stop" + LengthFinishReason = "length" + ToolsFinishReason = "tool_calls" + RemoteDecodeFinishReason = "remote_decode" +) + +// this array defines the probabilities for the buckets to be used for the generation of number of tokens in response +var respLenBucketsProbabilities = [...]float64{0.2, 0.3, 0.2, 0.05, 0.1, 0.15} +var cumulativeBucketsProbabilities []float64 + +const ( + flexBucketIndex = 3 + maxFixedBucketSize = 20 +) + +// list of responses to use in random mode for completion requests +var chatCompletionFakeResponses = []string{ + `Testing@, #testing 1$ ,2%,3^, [4&*5], 6~, 7-_ + (8 : 9) / \ < > .`, + `Testing, testing 1,2,3.`, + `I am fine, how are you today?`, + `I am your AI assistant, how can I help you today?`, + `Today is a nice sunny day.`, + `The temperature here is twenty-five degrees centigrade.`, + `Today it is partially cloudy and raining.`, + `To be or not to be that is the question.`, + `Alas, poor Yorick! I knew him, Horatio: A fellow of infinite jest`, + `The rest is silence. `, + `Give a man a fish and you feed him for a day; teach a man to fish and you feed him for a lifetime`, +} + +type Dataset interface { + // Init initializes the dataset using configs + Init(ctx context.Context, logger logr.Logger, path string, url string, useInMemory bool) error + // Close closes the dataset + Close() error + // GetTokens returns tokens for the given request and mode (echo or random) + GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error) +} + +func init() { + cumulativeBucketsProbabilities = make([]float64, len(respLenBucketsProbabilities)) + sum := 0.0 + + for i, val := range respLenBucketsProbabilities { + sum += val + cumulativeBucketsProbabilities[i] = sum + } +} + +// GetRandomResponseLen returns int in range [1, responseLenMax] +// numbers are chosen according a gaussian distribution with mean responseLenMean, and standard deviation responseLenStddev +func GetRandomResponseLen() int { + for { + val := rand.NormFloat64()*responseLenStddev + responseLenMean + if val >= 1 && val <= ResponseLenMax { + return int(math.Round(val)) + } + // else reject and resample + } +} + +// GetRandomFinishReason returns finish reason with the probability for 'stop' as defined by stopFinishReasonProbability +func GetRandomFinishReason() string { + if rand.Float64() < stopFinishReasonProbability { + return StopFinishReason + } + return LengthFinishReason +} + +// GenPresetRandomTokens generates random tokens for the required number of tokens, +// select randomly a sentence from chatCompletionFakeResponses, +// if number of tokens is lower than required - select another sentence, +// continue until the required number of tokens is achieved +func GenPresetRandomTokens(numOfTokens int) []string { + allTokens := make([]string, 0) + + for len(allTokens) < numOfTokens { + index := common.RandomInt(0, len(chatCompletionFakeResponses)-1) + // create tokens from text, splitting by spaces and special characters + tokens := common.Tokenize(chatCompletionFakeResponses[index]) + remaining := numOfTokens - len(allTokens) + + if len(tokens) > remaining { + // there is too many tokens, append only the relevant part + tokens = tokens[:remaining] + } + + if len(allTokens) > 0 { + // for not first sentences add space to the first token to separate between sentences without adding an additional token + tokens[0] = " " + tokens[0] + } + + allTokens = append(allTokens, tokens...) + } + + return allTokens +} + +// howManyTokensToGen generates the number of tokens to be returned in a response, and the finish reason (see constants) +// if maxCompletionTokens is defined +// - currently, the generated number of words in the text will be equal to it value +// - in future - need to find statistics about generated tokens distribution and return less tokens in part os requests +// - finish reason will be chosen randomly from the collection (stop, length) with 80% for stop and 20% for length +// if maxCompletionTokens is nil +// - the response text's length is randomly chosen from the range [1, responseLenMax] according additional parameters +// - finish reason is stop +// if ignore_eos is true - the response will be generated with exactly maxCompletionTokens tokens +// - request was validated so that when ignore_eos is true, maxCompletionTokens must be defined +func howManyTokensToGen(maxCompletionTokens *int64, ignore_eos bool) (int, string) { + numOfTokens := 0 + finishReason := StopFinishReason + + // no max completion tokens, return text with random length + if maxCompletionTokens == nil { + numOfTokens = GetRandomResponseLen() + } else { + maxTokens := int(*maxCompletionTokens) + if ignore_eos { + numOfTokens = maxTokens + finishReason = LengthFinishReason + } else { + // max tokens is defined - generate real length of the response based on it + numOfTokens = getResponseLengthByHistogram(maxTokens) + if numOfTokens == maxTokens { + // if response should be create with maximum number of tokens - finish reason will be 'length' + finishReason = LengthFinishReason + } + } + } + + return numOfTokens, finishReason +} + +// getResponseLengthByHistogram calculates the number of tokens to be returned in a response based on the max tokens value and the pre-defined buckets. +// The response length is distributed according to the probabilities, defined in respLenBucketsProbabilities. +// The histogram contains equally sized buckets and the last special bucket, which contains only the maxTokens value. +// The last element of respLenBucketsProbabilities defines the probability of a reposnse with maxToken tokens. +// Other values define probabilities for the equally sized buckets. +// If maxToken is small (smaller than number of buckets) - the response length is randomly selected from the range [1, maxTokens] +func getResponseLengthByHistogram(maxTokens int) int { + if maxTokens <= 1 { + return maxTokens + } + // maxTokens is small - no need to use the histogram of probabilities, just select a random value in the range [1, maxTokens] + if maxTokens <= len(cumulativeBucketsProbabilities) { + res := common.RandomInt(1, maxTokens) + return res + } + + r := common.RandomFloat(0, 1) + + // check if r is in the last bucket, then maxTokens should be returned + if r > cumulativeBucketsProbabilities[len(cumulativeBucketsProbabilities)-2] { + return maxTokens + } + + // determine which bucket to use, the bucket with a cumulative probability larger than r is the bucket to use + // initialize bucketIndex with the last bucket to handle the case (which should not happen) when the probabilities sum is less than 1 + bucketIndex := len(cumulativeBucketsProbabilities) - 1 + for i, c := range cumulativeBucketsProbabilities { + if r <= c { + bucketIndex = i + break + } + } + + // calculate the size of all of the buckets (except the special last bucket) + start, end := calcBucketBoundaries(maxTokens, bucketIndex) + + // pick uniformly within the bucket’s range + return common.RandomInt(start, end) +} + +// calcBucketBoundaries calculates boundaries of a bucket with the given index. +// Maximum size for equally sized buckets is defined by maxFixedBucketSize. +// [maxFixedBucketSize*(number-of-buckets-1)+1] is the value of maxTokens for which +// division to equally size buckets will give buckets with size maxFixedBucketSize. +// If maxTokens is [maxFixedBucketSize*(number-of-buckets-1)+1] or less, +// all buckets will be of equal size, except the last bucket, which contains only one value. +// If maxTokens is higher than [maxFixedBucketSize*(number-of-buckets-1)+1], +// and flexBucketIndex is valid (between 0 and number of buckets - 1) the buckets sizes will not be equal. +// In this case, all buckets except the one at flexBucketIndex index will have size 20 (and the last is with size 1), +// and the bucket at flexBucketIndex index will 'stretch' to cover the remaining range. +func calcBucketBoundaries(maxTokens int, bucketIndex int) (start int, end int) { + maxEquallyBucketsSz := maxFixedBucketSize*(len(cumulativeBucketsProbabilities)-1) + 1 + + if maxTokens <= maxEquallyBucketsSz || flexBucketIndex < 0 || flexBucketIndex >= len(cumulativeBucketsProbabilities)-1 { + // create equally size buckets + // calculate the size of all of the buckets (except the special last bucket) + bucketSize := float64(maxTokens-1) / float64(len(cumulativeBucketsProbabilities)-1) + start = int(bucketSize*float64(bucketIndex)) + 1 + end = int(bucketSize * float64(bucketIndex+1)) + } else { + // create non-equally sized buckets and find boundaries of the required bucket + if bucketIndex < flexBucketIndex { + // the relevant bucket is before the flex bucket, all buckets are of the same size (maxFixedBucketSize) + // start is the minimum number in the required bucket + start = maxFixedBucketSize*bucketIndex + 1 + end = maxFixedBucketSize * (bucketIndex + 1) + } else { + flexBucketSize := maxTokens - (maxFixedBucketSize * (len(cumulativeBucketsProbabilities) - 2)) + + if bucketIndex == flexBucketIndex { + // the relevant bucket is the flex bucket + start = int(maxFixedBucketSize*float64(bucketIndex)) + 1 + end = maxFixedBucketSize*bucketIndex + flexBucketSize + } else { + // the relevant bucket is one of buckets after the flex bucket + start = int(maxFixedBucketSize*float64(bucketIndex-1)) + flexBucketSize + 1 + end = maxFixedBucketSize*bucketIndex + flexBucketSize + } + } + } + + // sometimes end could be maxTokens because of rounding, change the value to maxToken-1 + if end >= maxTokens { + end = maxTokens - 1 + } + + return start, end +} + +// EchoResponseTokens returns needed tokens, from a given text +// considering max completion tokens if it is not nil, and a finish reason (stop or length) +func EchoResponseTokens(maxCompletionTokens *int64, text string) ([]string, string) { + tokens := common.Tokenize(text) + // no max completion tokens, return entire text + if maxCompletionTokens == nil { + return tokens, StopFinishReason + } + + if *maxCompletionTokens >= int64(len(tokens)) { + return tokens, StopFinishReason + } + // return truncated text + return tokens[0:*maxCompletionTokens], LengthFinishReason +} + +type BaseDataset struct { + logger logr.Logger +} + +func (d *BaseDataset) Init(ctx context.Context, logger logr.Logger, path string, url string, useInMemory bool) error { + d.logger = logger + return nil +} + +func (d *BaseDataset) Close() error { + return nil +} + +func (d *BaseDataset) echo(req openaiserverapi.CompletionRequest) ([]string, string, error) { + nMaxTokens := d.extractMaxTokens(req) + prompt, err := d.extractPrompt(req) + if err != nil { + return nil, "", err + } + tokens, finishReason := EchoResponseTokens(nMaxTokens, prompt) + return tokens, finishReason, nil +} + +// GetTokens returns tokens and finishReason for the given request and mode (echo or random) +func (d *BaseDataset) GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error) { + if mode == common.ModeEcho { + return d.echo(req) + } + nTokensToGen, finishReason := howManyTokensToGen(d.extractMaxTokens(req), req.GetIgnoreEOS()) + return GenPresetRandomTokens(nTokensToGen), finishReason, nil +} + +// extractMaxTokens extracts the max tokens from the request +// for chat completion - max_completion_tokens field is used +// for text completion - max_tokens field is used +func (d *BaseDataset) extractMaxTokens(req openaiserverapi.CompletionRequest) *int64 { + if chatReq, ok := req.(*openaiserverapi.ChatCompletionRequest); ok { + return chatReq.GetMaxCompletionTokens() + } else if textReq, ok := req.(*openaiserverapi.TextCompletionRequest); ok { + return textReq.MaxTokens + } + return nil +} + +// extractPrompt extracts the prompt from the request +// for chat completion - the last user message is used as the prompt +// for text completion - the prompt field is used +func (d *BaseDataset) extractPrompt(req openaiserverapi.CompletionRequest) (string, error) { + if chatReq, ok := req.(*openaiserverapi.ChatCompletionRequest); ok { + return chatReq.GetLastUserMsg(), nil + } else if textReq, ok := req.(*openaiserverapi.TextCompletionRequest); ok { + return textReq.GetPrompt(), nil + } + return "", errors.New("unknown request type") +} diff --git a/pkg/dataset/dataset_suite_test.go b/pkg/dataset/dataset_suite_test.go new file mode 100644 index 00000000..c9dea52c --- /dev/null +++ b/pkg/dataset/dataset_suite_test.go @@ -0,0 +1,13 @@ +package dataset_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestDataset(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Dataset Suite") +} diff --git a/pkg/dataset/dataset_test.go b/pkg/dataset/dataset_test.go new file mode 100644 index 00000000..83a2953b --- /dev/null +++ b/pkg/dataset/dataset_test.go @@ -0,0 +1,204 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package dataset + +import ( + "fmt" + "strings" + "time" + + "github.com/llm-d/llm-d-inference-sim/pkg/common" + openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Dataset", Ordered, func() { + var ( + dataset *BaseDataset + ) + + BeforeAll(func() { + common.InitRandom(time.Now().UnixNano()) + }) + + BeforeEach(func() { + dataset = &BaseDataset{} + }) + + Context("GetRandomTokens", func() { + + It("should return complete text", func() { + req := &openaiserverapi.ChatCompletionRequest{} + tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom) + Expect(err).ShouldNot(HaveOccurred()) + text := strings.Join(tokens, "") + Expect(IsValidText(text)).To(BeTrue()) + Expect(finishReason).Should(Equal(StopFinishReason)) + }) + + It("should return short text", func() { + maxCompletionTokens := int64(2) + req := &openaiserverapi.ChatCompletionRequest{ + MaxCompletionTokens: &maxCompletionTokens, + } + tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom) + Expect(err).ShouldNot(HaveOccurred()) + tokensCnt := int64(len(tokens)) + Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens)) + if tokensCnt == maxCompletionTokens { + Expect(finishReason).To(Equal(LengthFinishReason)) + } else { + Expect(tokensCnt).To(BeNumerically("<", maxCompletionTokens)) + Expect(finishReason).To(Equal(StopFinishReason)) + } + }) + + It("should return long text", func() { + // return required number of tokens although it is higher than ResponseLenMax + maxCompletionTokens := int64(ResponseLenMax * 5) + req := &openaiserverapi.ChatCompletionRequest{ + MaxTokens: &maxCompletionTokens, + } + tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom) + Expect(err).ShouldNot(HaveOccurred()) + tokensCnt := int64(len(tokens)) + Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens)) + text := strings.Join(tokens, "") + Expect(IsValidText(text)).To(BeTrue()) + if tokensCnt == maxCompletionTokens { + Expect(finishReason).To(Equal(LengthFinishReason)) + } else { + Expect(tokensCnt).To(BeNumerically("<", maxCompletionTokens)) + Expect(finishReason).To(Equal(StopFinishReason)) + } + }) + + DescribeTable("should return exact num of tokens", + func(maxCompletionTokens int) { + n := int64(maxCompletionTokens) + req := &openaiserverapi.ChatCompletionRequest{ + BaseCompletionRequest: openaiserverapi.BaseCompletionRequest{ + IgnoreEOS: true, + }, + MaxTokens: &n, + } + tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom) + Expect(err).ShouldNot(HaveOccurred()) + nGenTokens := int64(len(tokens)) + Expect(nGenTokens).Should(Equal(n)) + Expect(finishReason).To(Equal(LengthFinishReason)) + }, + func(maxCompletionTokens int) string { + return fmt.Sprintf("maxCompletionTokens: %d", maxCompletionTokens) + }, + Entry("1", 1), + Entry("42", 42), + Entry("99", 99), + Entry("10000", 10000), + ) + }) + + Context("GetResponseTokens", func() { + theText := "Give a man a fish and you feed him for a day; teach a man to fish and you feed him for a lifetime" + theTokens := common.Tokenize(theText) + + It("should return the same text since max tokens is not defined", func() { + tokens, finishReason := EchoResponseTokens(nil, theText) + Expect(tokens).Should(Equal(theTokens)) + Expect(finishReason).Should(Equal(StopFinishReason)) + }) + It("should return the same text since max tokens is higher than the text length", func() { + maxCompletionTokens := int64(1000) + tokens, finishReason := EchoResponseTokens(&maxCompletionTokens, theText) + Expect(tokens).Should(Equal(theTokens)) + Expect(finishReason).Should(Equal(StopFinishReason)) + }) + It("should return partial text", func() { + maxCompletionTokens := int64(2) + tokens, finishReason := EchoResponseTokens(&maxCompletionTokens, theText) + Expect(int64(len(tokens))).Should(Equal(maxCompletionTokens)) + Expect(finishReason).Should(Equal(LengthFinishReason)) + }) + }) + + Context("GetRandomTokens", func() { + lenArr := []int{5, 20, 50, 150} + + for _, len := range lenArr { + name := fmt.Sprintf("should return text with %d tokens", len) + It(name, func() { + tokens := GenPresetRandomTokens(len) + Expect(tokens).Should(HaveLen(len)) + }) + } + }) + + Context("IsValidText", func() { + validTxts := make([]string, 0) + invalidTxts := make([]string, 0) + + validTxts = append(validTxts, chatCompletionFakeResponses[0][:4]) + validTxts = append(validTxts, chatCompletionFakeResponses[1]) + validTxts = append(validTxts, chatCompletionFakeResponses[1]+" "+chatCompletionFakeResponses[2]) + + invalidTxts = append(invalidTxts, (chatCompletionFakeResponses[1] + " " + chatCompletionFakeResponses[2])[3:4]) + invalidTxts = append(invalidTxts, chatCompletionFakeResponses[0][4:]) + invalidTxts = append(invalidTxts, chatCompletionFakeResponses[1]+"-"+chatCompletionFakeResponses[2]) + invalidTxts = append(invalidTxts, chatCompletionFakeResponses[1]+" ") + invalidTxts = append(invalidTxts, chatCompletionFakeResponses[1]+" "+chatCompletionFakeResponses[2]) + + for _, txt := range validTxts { + It("text should be valid", func() { + Expect(IsValidText(txt)).To(BeTrue()) + }) + } + + for _, txt := range invalidTxts { + It("text should be invalid", func() { + Expect(IsValidText(txt)).To(BeFalse()) + }) + } + }) + + Context("validateBucketsBoundaries", func() { + type bucketBoundaries struct { + start int + end int + } + type bucketTest struct { + maxTokens int + expectedBuckets []bucketBoundaries + } + + tests := []bucketTest{{500, []bucketBoundaries{{1, 20}, {21, 40}, {41, 60}, {61, 480}, {481, 499}}}, + {47, []bucketBoundaries{{1, 9}, {10, 18}, {19, 27}, {28, 36}, {37, 46}}}, + {50, []bucketBoundaries{{1, 9}, {10, 19}, {20, 29}, {30, 39}, {40, 49}}}} + + for _, test := range tests { + Expect(test.expectedBuckets).To(HaveLen(len(cumulativeBucketsProbabilities) - 1)) + + It(fmt.Sprintf("should return bucket boundaries for maxTokens %d", test.maxTokens), func() { + for i := range len(cumulativeBucketsProbabilities) - 1 { + start, end := calcBucketBoundaries(test.maxTokens, i) + Expect(start).To(Equal(test.expectedBuckets[i].start)) + Expect(end).To(Equal(test.expectedBuckets[i].end)) + } + }) + } + }) +}) diff --git a/pkg/common/test_helpers.go b/pkg/dataset/test_helpers.go similarity index 98% rename from pkg/common/test_helpers.go rename to pkg/dataset/test_helpers.go index 31ff4bd5..ed6b6f1b 100644 --- a/pkg/common/test_helpers.go +++ b/pkg/dataset/test_helpers.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package common +package dataset import "strings" diff --git a/pkg/llm-d-inference-sim/server.go b/pkg/llm-d-inference-sim/server.go index 1c6284a1..6384f28d 100644 --- a/pkg/llm-d-inference-sim/server.go +++ b/pkg/llm-d-inference-sim/server.go @@ -34,7 +34,6 @@ import ( ) func (s *VllmSimulator) newListener() (net.Listener, error) { - s.logger.Info("Server starting", "port", s.config.Port) listener, err := net.Listen("tcp4", fmt.Sprintf(":%d", s.config.Port)) if err != nil { return nil, err diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index ab55fea2..e5d70ede 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -32,6 +32,7 @@ import ( "k8s.io/klog/v2" "github.com/llm-d/llm-d-inference-sim/pkg/common" + "github.com/llm-d/llm-d-inference-sim/pkg/dataset" kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache" openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" vllmapi "github.com/llm-d/llm-d-inference-sim/pkg/vllm-api" @@ -115,6 +116,8 @@ type VllmSimulator struct { pod string // tokenizer is currently used in kv-cache and in /tokenize tokenizer tokenization.Tokenizer + // dataset is used for token generation in responses + dataset dataset.Dataset } // New creates a new VllmSimulator instance with the given logger @@ -213,6 +216,11 @@ func (s *VllmSimulator) startSim(ctx context.Context) error { go s.kvcacheHelper.Run(ctx) } + err = s.initDataset(ctx) + if err != nil { + return fmt.Errorf("dataset initialization error: %w", err) + } + // run request processing workers for i := 1; i <= s.config.MaxNumSeqs; i++ { go s.reqProcessingWorker(ctx, i) @@ -222,13 +230,44 @@ func (s *VllmSimulator) startSim(ctx context.Context) error { listener, err := s.newListener() if err != nil { - return err + s.logger.Error(err, "Failed to create listener") + return fmt.Errorf("listener creation error: %w", err) } // start the http server with context support return s.startServer(ctx, listener) } +func (s *VllmSimulator) initDataset(ctx context.Context) error { + randDataset := &dataset.BaseDataset{} + err := randDataset.Init(ctx, s.logger, "", "", false) + if err != nil { + return fmt.Errorf("failed to initialize random dataset: %w", err) + } + + if s.config.DatasetPath == "" && s.config.DatasetURL == "" { + s.logger.Info("No dataset path or URL provided, using random text for responses") + s.dataset = randDataset + return nil + } + + custDataset := &dataset.CustomDataset{} + err = custDataset.Init(ctx, s.logger, s.config.DatasetPath, s.config.DatasetURL, s.config.DatasetInMemory) + + if err == nil { + s.dataset = custDataset + return nil + } + + if strings.HasPrefix(err.Error(), "database is locked") { + s.logger.Info("Database is locked by another process, will use preset text for responses instead") + s.dataset = randDataset + return nil + } + + return err +} + // Print prints to a log, implementation of fasthttp.Logger func (s *VllmSimulator) Printf(format string, args ...interface{}) { s.logger.Info("Server error", "msg", fmt.Sprintf(format, args...)) @@ -316,13 +355,15 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) { if reqCtx.IsChatCompletion && req.GetToolChoice() != openaiserverapi.ToolChoiceNone && req.GetTools() != nil { - toolCalls, finishReason, completionTokens, err = + toolCalls, completionTokens, err = openaiserverapi.CreateToolCalls(req.GetTools(), req.GetToolChoice(), s.config) + finishReason = dataset.ToolsFinishReason } if toolCalls == nil && err == nil { // Either no tool calls were defined, or we randomly chose not to create tool calls, // so we generate a response text. - responseTokens, finishReason, completionTokens, err = req.CreateResponseText(s.config.Mode) + responseTokens, finishReason, err = s.dataset.GetTokens(req, s.config.Mode) + completionTokens += len(responseTokens) } if err != nil { prefix := "" @@ -358,7 +399,7 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) { } else { if req.IsDoRemoteDecode() { // in case this is prefill pod processing, return special finish reason - finishReason = common.RemoteDecodeFinishReason + finishReason = dataset.RemoteDecodeFinishReason } s.sendResponse(reqCtx, responseTokens, toolCalls, displayModel, finishReason, &usageData) diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index e504c5d5..e8e57e41 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -28,6 +28,7 @@ import ( "strings" "github.com/llm-d/llm-d-inference-sim/pkg/common" + "github.com/llm-d/llm-d-inference-sim/pkg/dataset" kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache" "github.com/llm-d/llm-d-kv-cache-manager/pkg/tokenization" . "github.com/onsi/ginkgo/v2" @@ -117,6 +118,11 @@ func startServerWithArgs(ctx context.Context, mode string, args []string, envs m go s.kvcacheHelper.Run(ctx) } + err = s.initDataset(ctx) + if err != nil { + return nil, fmt.Errorf("dataset initialization error: %w", err) + } + // calculate number of tokens for user message, // must be activated after parseCommandParamsAndLoadConfig since it initializes the random engine userMsgTokens = int64(len(common.Tokenize(userMessage))) @@ -190,7 +196,7 @@ var _ = Describe("Simulator", func() { msg := strings.Join(tokens, "") if mode == common.ModeRandom { // in case of random mode ensure that the returned message could be output of the random text generator - Expect(common.IsValidText(msg)).To(BeTrue()) + Expect(dataset.IsValidText(msg)).To(BeTrue()) } else { // in case of echo mode check that the text is returned as-is Expect(msg).Should(Equal(userMessage)) @@ -239,7 +245,7 @@ var _ = Describe("Simulator", func() { text := strings.Join(tokens, "") if mode == common.ModeRandom { // in case of random mode ensure that the returned message could be output of the random text generator - Expect(common.IsValidText(text)).To(BeTrue()) + Expect(dataset.IsValidText(text)).To(BeTrue()) } else { // in case of echo mode check that the text is returned as-is Expect(text).Should(Equal(userMessage)) @@ -300,7 +306,7 @@ var _ = Describe("Simulator", func() { } else { if mode == common.ModeRandom { // in case of random mode ensure that the returned message could be output of the random text generator - Expect(common.IsValidText(msg)).To(BeTrue()) + Expect(dataset.IsValidText(msg)).To(BeTrue()) } else { // in case of echo mode check that the text is returned as-is Expect(msg).Should(Equal(userMessage)) @@ -371,7 +377,7 @@ var _ = Describe("Simulator", func() { } else { if mode == common.ModeRandom { // in case of random mode ensure that the returned message could be output of the random text generator - Expect(common.IsValidText(text)).To(BeTrue()) + Expect(dataset.IsValidText(text)).To(BeTrue()) } else { // in case of echo mode check that the text is returned as-is Expect(text).Should(Equal(userMessage)) diff --git a/pkg/llm-d-inference-sim/streaming.go b/pkg/llm-d-inference-sim/streaming.go index 2508298d..c64affc8 100644 --- a/pkg/llm-d-inference-sim/streaming.go +++ b/pkg/llm-d-inference-sim/streaming.go @@ -23,6 +23,7 @@ import ( "time" "github.com/llm-d/llm-d-inference-sim/pkg/common" + "github.com/llm-d/llm-d-inference-sim/pkg/dataset" openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" "github.com/valyala/fasthttp" ) @@ -124,7 +125,7 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ var chunk openaiserverapi.CompletionRespChunk var finishReasonToSend *string - if i == len(genTokens)-1 && (finishReason == common.LengthFinishReason || finishReason == common.ToolsFinishReason) { + if i == len(genTokens)-1 && (finishReason == dataset.LengthFinishReason || finishReason == dataset.ToolsFinishReason) { finishReasonToSend = &finishReason } if context.isChatCompletion { @@ -141,7 +142,7 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ // send the last chunk if finish reason is stop var chunk openaiserverapi.CompletionRespChunk - if finishReason == common.StopFinishReason { + if finishReason == dataset.StopFinishReason { if context.isChatCompletion { chunk = s.createChatCompletionChunk(context, "", nil, "", &finishReason) } else { diff --git a/pkg/llm-d-inference-sim/tools_test.go b/pkg/llm-d-inference-sim/tools_test.go index ae22a7f6..bffb7eea 100644 --- a/pkg/llm-d-inference-sim/tools_test.go +++ b/pkg/llm-d-inference-sim/tools_test.go @@ -24,6 +24,7 @@ import ( "strings" "github.com/llm-d/llm-d-inference-sim/pkg/common" + "github.com/llm-d/llm-d-inference-sim/pkg/dataset" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/openai/openai-go" @@ -365,7 +366,7 @@ var _ = Describe("Simulator for request with tools", func() { for _, choice := range chunk.Choices { if choice.Delta.Role != "" { role = choice.Delta.Role - } else if choice.FinishReason == "" || choice.FinishReason == common.ToolsFinishReason { + } else if choice.FinishReason == "" || choice.FinishReason == dataset.ToolsFinishReason { toolCalls := choice.Delta.ToolCalls Expect(toolCalls).To(HaveLen(1)) tc := toolCalls[0] diff --git a/pkg/openai-server-api/request.go b/pkg/openai-server-api/request.go index e7d5fb3b..34db0ee6 100644 --- a/pkg/openai-server-api/request.go +++ b/pkg/openai-server-api/request.go @@ -33,10 +33,6 @@ const ( type CompletionRequest interface { // GetRequestID returns the unique request id GetRequestID() string - // CreateResponseText creates and returns response payload based on this request, - // i.e., an array of generated tokens, the finish reason, and the number of created - // tokens - CreateResponseText(mode string) ([]string, string, int, error) // IsStream returns boolean that defines is response should be streamed IsStream() bool // GetModel returns model name as defined in the request @@ -69,10 +65,12 @@ type CompletionRequest interface { // when the field is true, the prefill phase should be done on remote pod, // whereas decode phase is done on local pod, thus this is a decode request IsDoRemotePrefill() bool + // GetFullPrompt returns the full prompt including system and user prompts + GetFullPrompt() string } -// baseCompletionRequest contains base completion request related information -type baseCompletionRequest struct { +// BaseCompletionRequest contains base completion request related information +type BaseCompletionRequest struct { // RequestID is the unique id of this request RequestID string // Stream is a boolean value, defines whether response should be sent as a Stream @@ -105,44 +103,44 @@ type StreamOptions struct { IncludeUsage bool `json:"include_usage"` } -func (b *baseCompletionRequest) GetRequestID() string { +func (b *BaseCompletionRequest) GetRequestID() string { return b.RequestID } -func (b *baseCompletionRequest) IsStream() bool { +func (b *BaseCompletionRequest) IsStream() bool { return b.Stream } -func (b *baseCompletionRequest) GetModel() string { +func (b *BaseCompletionRequest) GetModel() string { return b.Model } -func (b *baseCompletionRequest) IncludeUsage() bool { +func (b *BaseCompletionRequest) IncludeUsage() bool { return !b.Stream || b.StreamOptions.IncludeUsage } -func (b *baseCompletionRequest) IsDoRemoteDecode() bool { +func (b *BaseCompletionRequest) IsDoRemoteDecode() bool { return b.DoRemoteDecode } -func (b *baseCompletionRequest) IsDoRemotePrefill() bool { +func (b *BaseCompletionRequest) IsDoRemotePrefill() bool { return b.DoRemotePrefill } // GetNumberOfCachedPromptTokens returns the number of tokens in the prompt that are // in the local KV Cache -func (b *baseCompletionRequest) GetNumberOfCachedPromptTokens() int { +func (b *BaseCompletionRequest) GetNumberOfCachedPromptTokens() int { return b.cachedPromptTokens } // GetIgnoreEOS returns the value of IgnoreEOS -func (b *baseCompletionRequest) GetIgnoreEOS() bool { +func (b *BaseCompletionRequest) GetIgnoreEOS() bool { return b.IgnoreEOS } // SetNumberOfCachedPromptTokens sets the number of tokens in the prompt that are // in the local KV Cache -func (b *baseCompletionRequest) SetNumberOfCachedPromptTokens(cachedPromptTokens int) { +func (b *BaseCompletionRequest) SetNumberOfCachedPromptTokens(cachedPromptTokens int) { b.cachedPromptTokens = cachedPromptTokens } @@ -157,7 +155,7 @@ type CompletionReqCtx struct { // ChatCompletionRequest defines structure of /chat/completion request type ChatCompletionRequest struct { - baseCompletionRequest + BaseCompletionRequest // Messages list of request's Messages Messages []Message `json:"messages"` @@ -230,7 +228,7 @@ func (c *ChatCompletionRequest) GetMaxCompletionTokens() *int64 { // getLastUserMsg returns last message from this request's messages with user role, // if does not exist - returns an empty string -func (req *ChatCompletionRequest) getLastUserMsg() string { +func (req *ChatCompletionRequest) GetLastUserMsg() string { for i := len(req.Messages) - 1; i >= 0; i-- { if req.Messages[i].Role == RoleUser { return req.Messages[i].Content.PlainText() @@ -240,30 +238,25 @@ func (req *ChatCompletionRequest) getLastUserMsg() string { return "" } -// CreateResponseText creates and returns response payload based on this request, -// i.e., an array of generated tokens, the finish reason, and the number of created -// tokens -func (req ChatCompletionRequest) CreateResponseText(mode string) ([]string, string, int, error) { - maxTokens, err := common.GetMaxTokens(req.MaxCompletionTokens, req.MaxTokens) - if err != nil { - return nil, "", 0, err - } - - var text, finishReason string - if mode == common.ModeEcho { - text, finishReason = common.GetResponseText(maxTokens, req.getLastUserMsg()) - } else { - text, finishReason = common.GetRandomResponseText(maxTokens, req.GetIgnoreEOS()) +func (req *ChatCompletionRequest) GetFullPrompt() string { + prompt := "" + for _, msg := range req.Messages { + switch msg.Role { + case RoleUser: + prompt += "### user:\n" + msg.Content.Raw + "\n" + case RoleAssistant: + prompt += "### assistant:\n" + msg.Content.Raw + "\n" + default: + prompt += "### unknown:\n" + msg.Content.Raw + "\n" + } } - - tokens := common.Tokenize(text) - return tokens, finishReason, len(tokens), nil + return prompt } // v1/completion // TextCompletionRequest defines structure of /completion request type TextCompletionRequest struct { - baseCompletionRequest + BaseCompletionRequest // Prompt defines request's content Prompt string `json:"prompt"` @@ -295,22 +288,6 @@ func (c *TextCompletionRequest) GetMaxCompletionTokens() *int64 { return c.MaxTokens } -// CreateResponseText creates and returns response payload based on this request, -// i.e., an array of generated tokens, the finish reason, and the number of created -// tokens -func (req TextCompletionRequest) CreateResponseText(mode string) ([]string, string, int, error) { - maxTokens, err := common.GetMaxTokens(nil, req.MaxTokens) - if err != nil { - return nil, "", 0, err - } - - var text, finishReason string - if mode == common.ModeEcho { - text, finishReason = common.GetResponseText(maxTokens, req.Prompt) - } else { - text, finishReason = common.GetRandomResponseText(maxTokens, req.GetIgnoreEOS()) - } - - tokens := common.Tokenize(text) - return tokens, finishReason, len(tokens), nil +func (t *TextCompletionRequest) GetFullPrompt() string { + return "### user:\n" + t.Prompt + "\n" } diff --git a/pkg/openai-server-api/tools_utils.go b/pkg/openai-server-api/tools_utils.go index 3546aa9d..58f3a0df 100644 --- a/pkg/openai-server-api/tools_utils.go +++ b/pkg/openai-server-api/tools_utils.go @@ -55,7 +55,7 @@ var fakeStringArguments = []string{ // CreateToolCalls creates and returns response payload based on this request // (tool calls or nothing in case we randomly choose not to generate calls), // and the number of generated completion token sand the finish reason -func CreateToolCalls(tools []Tool, toolChoice string, config *common.Configuration) ([]ToolCall, string, int, error) { +func CreateToolCalls(tools []Tool, toolChoice string, config *common.Configuration) ([]ToolCall, int, error) { // This function is called if tool choice is either 'required' or 'auto'. // In case of 'required' at least one tool call has to be created, and we randomly choose // the number of calls starting from one. Otherwise, we start from 0, and in case we randomly @@ -66,7 +66,7 @@ func CreateToolCalls(tools []Tool, toolChoice string, config *common.Configurati } numberOfCalls := common.RandomInt(min, len(tools)) if numberOfCalls == 0 { - return nil, "", 0, nil + return nil, 0, nil } calls := make([]ToolCall, 0) @@ -75,11 +75,11 @@ func CreateToolCalls(tools []Tool, toolChoice string, config *common.Configurati index := common.RandomInt(0, len(tools)-1) args, err := GenerateToolArguments(tools[index], config) if err != nil { - return nil, "", 0, err + return nil, 0, err } argsJson, err := json.Marshal(args) if err != nil { - return nil, "", 0, err + return nil, 0, err } call := ToolCall{ @@ -95,7 +95,7 @@ func CreateToolCalls(tools []Tool, toolChoice string, config *common.Configurati calls = append(calls, call) } - return calls, common.ToolsFinishReason, CountTokensForToolCalls(calls), nil + return calls, CountTokensForToolCalls(calls), nil } func GetRequiredAsMap(property map[string]any) map[string]struct{} {