diff --git a/README.md b/README.md index b162c7b..6d1a80b 100644 --- a/README.md +++ b/README.md @@ -362,3 +362,53 @@ curl -X POST http://localhost:8000/v1/chat/completions \ ] }' ``` + +## Response generation + +The `/v1/completions` and `/v1/chat/completions` endpoints produce responses based on the simulator configuration and the parameters of the specific request. + +### Echo mode +In `echo` mode, responses always mirror the request content. The parameters `max_tokens`, `max_completion_tokens`, and `ignore_eos` are ignored in this mode: the prompt is returned in full, and `max_tokens` only affects the reported finish reason. + +### Random mode +In `random` mode, the fields `max_tokens`, `max_completion_tokens`, and `ignore_eos` from the request are used during response generation. + +#### Use predefined texts for response generation +The simulator can generate responses from a predefined list of sentences. +If `max_tokens` or `max_completion_tokens` is specified, the response length is calculated using a histogram with six buckets and the following probabilities: 20%, 30%, 20%, 5%, 10%, 15%. +For a maximum length ≤ 120, bucket sizes are equal. +For a maximum length > 120, all buckets except the fourth are of size 20; +the fourth bucket stretches to cover the remaining range. +After the buckets are set, the response length is sampled according to these probabilities. + + +Examples:
+max-len = 120: buckets are 1-20, 21-40, 41-60, 61-80, 81-100, 101-120.
+max-len = 200: buckets are 1-20, 21-40, 41-60, 61-160, 161-180, 181-200.
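+
+The following sketch illustrates the sampling idea for the equal-bucket case (maximum length ≤ 120, assumed here to be a multiple of six). It is a simplified illustration, not the simulator's exact arithmetic — the real boundary computation, including the stretched fourth bucket, lives in `pkg/dataset/histogram_helper.go` — and the function below is hypothetical:
+
+```go
+package main
+
+import (
+	"fmt"
+	"math/rand"
+)
+
+// sampleResponseLen picks a bucket according to the cumulative bucket
+// probabilities and then samples a length uniformly within that bucket.
+func sampleResponseLen(r *rand.Rand, maxLen int) int {
+	probs := []float64{0.2, 0.3, 0.2, 0.05, 0.1, 0.15}
+	x := r.Float64()
+	cum := 0.0
+	bucket := len(probs) - 1 // fall back to the last bucket
+	for i, p := range probs {
+		cum += p
+		if x <= cum {
+			bucket = i
+			break
+		}
+	}
+	size := maxLen / len(probs) // e.g. 20 for maxLen == 120
+	start := bucket*size + 1
+	end := (bucket + 1) * size
+	return start + r.Intn(end-start+1)
+}
+
+func main() {
+	r := rand.New(rand.NewSource(42))
+	fmt.Println(sampleResponseLen(r, 120)) // a length in [1, 120]
+}
+```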
+ +If the maximum response length is not specified, it defaults to the model's maximum context length (`max-model-len`) minus the number of tokens in the prompt. +In this case, the response length is sampled from a Gaussian distribution with mean 40 and standard deviation 20. + + +After determining the response length: + +A random sentence from the predefined list is chosen and trimmed if it exceeds the required length. +If the sentence is shorter, additional random sentences are concatenated until the required token count is met. + +If `ignore_eos` is true, the response always reaches the maximum allowed length. + +The finish_reason is set to LENGTH if the response length equals the maximum; otherwise, it is set to STOP. + + +#### Use responses dataset for response generation +If `dataset-url` is set on the command line, the dataset is downloaded to the location specified by `dataset-path`. + +If a valid dataset exists at `dataset-path`, it is used for response selection: +The request prompt is hashed, and the hash is matched against dataset entries. +If matching responses exist that fit within the maximum response length, one of them is selected at random; +if all matches are longer, a random match is selected and then trimmed. + +If `ignore_eos` is true and no match has the required length, the response is padded with random tokens from the predefined list. + +If the prompt hash is not present in the dataset, a random response of length ≤ the maximum is selected; +if all responses are longer, a random response is chosen and trimmed. + diff --git a/pkg/common/config.go b/pkg/common/config.go index 7929b2b..bc51be9 100644 --- a/pkg/common/config.go +++ b/pkg/common/config.go @@ -660,6 +660,10 @@ func (c *Configuration) validate() error { return errors.New("dataset-path is required when dataset-url is set") } + if c.Mode == ModeEcho && (c.DatasetPath != "" || c.DatasetURL != "") { + return errors.New("dataset cannot be defined in echo mode") + } + return nil } diff --git a/pkg/common/config_test.go b/pkg/common/config_test.go index ac5a42e..37cb180 100644 --- a/pkg/common/config_test.go +++ b/pkg/common/config_test.go @@ -532,6 +532,12 @@ var _ = Describe("Simulator configuration", func() { "--config", "../../manifests/config.yaml"}, expectedError: "fake metrics request-max-generation-tokens cannot contain negative values", }, + { + name: "invalid echo mode with dataset", + args: []string{"random", "--model", "test", "--dataset-path", "my/path", + "--mode", "echo"}, + expectedError: "dataset cannot be defined in echo mode", + }, } for _, test := range invalidTests { diff --git a/pkg/dataset/custom_dataset.go b/pkg/dataset/custom_dataset.go index 557321d..dc6a552 100644 --- a/pkg/dataset/custom_dataset.go +++ b/pkg/dataset/custom_dataset.go @@ -19,17 +19,8 @@ package dataset import ( "context" "crypto/sha256" - "database/sql" "encoding/hex" - "encoding/json" "errors" - "fmt" - "io" - "net/http" - "os" - "path/filepath" - "strconv" - "time" "github.com/go-logr/logr" "github.com/llm-d/llm-d-inference-sim/pkg/common" @@ -39,469 +30,145 @@ import ( type CustomDataset struct { BaseDataset - db *sql.DB - hasWarned bool + sqliteHelper *sqliteHelper } -// use constants for expected column names and types -const ( - tableName = "llmd" - idCol = "id" - promptHashCol = "prompt_hash" - genTokensCol = "gen_tokens" - nGenTokensCol = "n_gen_tokens" - idColType = "INTEGER" - promptHashColType = "BLOB" - genTokensColType = "JSON" - nGenTokensColType = "INTEGER" - progressLogTimeInterval = 5 * time.Second - progressLogPercentInterval = 10 -) - -func (d *CustomDataset) downloadDataset(ctx context.Context, url string, path string) error { - folder := filepath.Dir(path) - err := os.MkdirAll(folder, 
0755) - if err != nil { - return fmt.Errorf("failed to create parent directory: %w", err) - } - - if _, err := os.Stat(path); err == nil { - // file already exists - return errors.New("Dataset file already exists, should not download: " + path) - } - - out, err := os.Create(path) - if err != nil { +func (d *CustomDataset) Init(ctx context.Context, logger logr.Logger, random *common.Random, + path string, useInMemory bool, maxModelLen int) error { + if err := d.BaseDataset.Init(ctx, logger, random, maxModelLen); err != nil { return err } - defer func() { - cerr := out.Close() - if cerr != nil { - d.logger.Error(cerr, "failed to close file after download") - } - }() - - d.logger.V(logging.INFO).Info("Using dataset-url", "dataset-url", url) - resp, err := http.Get(url) - if err != nil { - return err - } - defer func() { - cerr := resp.Body.Close() - if cerr != nil { - d.logger.Error(cerr, "failed to close response body after download") - } - }() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("bad status: %s", resp.Status) - } - - // Progress reader with context - pr := &progressReader{ - Reader: resp.Body, - total: resp.ContentLength, - logger: d.logger, - ctx: ctx, - startTime: time.Now(), - } - - written, err := io.Copy(out, pr) - if err != nil { - // Remove incomplete file - cerr := os.Remove(path) - if cerr != nil { - d.logger.Error(cerr, "failed to remove incomplete file after download") - } - // If context was cancelled, return a specific error - if errors.Is(err, context.Canceled) { - return errors.New("download cancelled by user") - } - return fmt.Errorf("failed to download file: %w", err) - } - // Check if file size is zero - if written == 0 { - cerr := os.Remove(path) - if cerr != nil { - d.logger.Error(cerr, "failed to remove empty file after download") - } - return errors.New("downloaded file is empty") - } - - // Ensure file is fully flushed and closed before returning success - if err := out.Sync(); err != nil { - cerr := os.Remove(path) - if cerr != nil { - d.logger.Error(cerr, "failed to remove incomplete file after download") - } - return fmt.Errorf("failed to sync file: %w", err) + if path == "" { + return errors.New("no dataset path provided") } - return nil + d.sqliteHelper = newSqliteHelper(logger) + d.logger.V(logging.INFO).Info("Using dataset from", "path", path) + return d.sqliteHelper.connectToDB(path, useInMemory) } -// progressReader wraps an io.Reader and logs download progress. 
-type progressReader struct { - io.Reader - total int64 - downloaded int64 - startTime time.Time - lastPct int - lastLogTime time.Time - logger logr.Logger - ctx context.Context -} - -func (pr *progressReader) Read(p []byte) (int, error) { - select { - case <-pr.ctx.Done(): - return 0, pr.ctx.Err() - default: - } - n, err := pr.Reader.Read(p) - pr.downloaded += int64(n) - if pr.total > 0 { - pct := int(float64(pr.downloaded) * 100 / float64(pr.total)) - now := time.Now() - - timeSinceLastLog := now.Sub(pr.lastLogTime).Seconds() - pctDiff := pct - pr.lastPct - - if timeSinceLastLog >= progressLogTimeInterval.Seconds() || (pctDiff >= progressLogPercentInterval && pct != pr.lastPct) { - // progress will be shown every interval seconds or every interval percent of progress - pr.logProgress(pct) - pr.lastPct = pct - pr.lastLogTime = now - } - } - return n, err +func (d *CustomDataset) getPromptHash(req openaiserverapi.CompletionRequest) []byte { + hashArray := sha256.Sum256([]byte(req.GetFullPrompt())) + return hashArray[:] } -func (pr *progressReader) logProgress(pct int) { - elapsedTime := time.Since(pr.startTime).Seconds() - speed := float64(pr.downloaded) / (1024 * 1024 * elapsedTime) - remainingTime := float64(pr.total-pr.downloaded) / (float64(pr.downloaded) / elapsedTime) - if pct != 100 { - pr.logger.V(logging.INFO).Info("Dataset download progress", "%", pct, "speed (MB/s)", speed, "remaining time (s)", remainingTime) - } else { - pr.logger.V(logging.INFO).Info("Download completed", "average speed (MB/s)", speed, "total time (s)", elapsedTime) - } +func (d *CustomDataset) getPromptHashHex(hashBytes []byte) string { + return hex.EncodeToString(hashBytes) } -func (d *CustomDataset) verifyDB() error { - rows, err := d.db.Query("PRAGMA table_info(" + tableName + ");") - if err != nil { - return fmt.Errorf("failed to query table info for `%s`: %w", tableName, err) - } - defer func() { - if cerr := rows.Close(); cerr != nil { - d.logger.Error(cerr, "failed to close rows after querying table info") +// categorizeResponses receives list of responses tokens and maximum response length +// categorize responses to three collections: +// - shorter or equal length to maxLen +// - exact maxLen length +// - longer than maxLen +func (d *CustomDataset) categorizeResponses(responses [][]string, maxLen int) (shorterOrEqLen [][]string, equalLen [][]string, longerLen [][]string) { + for _, respTokens := range responses { + switch { + case len(respTokens) == maxLen: + shorterOrEqLen = append(shorterOrEqLen, respTokens) + equalLen = append(equalLen, respTokens) + case len(respTokens) < maxLen: + shorterOrEqLen = append(shorterOrEqLen, respTokens) + default: + longerLen = append(longerLen, respTokens) } - }() - - expectedColumns := map[string]string{ - idCol: idColType, - promptHashCol: promptHashColType, - genTokensCol: genTokensColType, - nGenTokensCol: nGenTokensColType, } - - columnsFound := make(map[string]bool) - - var ( - columnName string - columnType string - cid int - notnull int - dfltValue interface{} - pk int - ) - - for rows.Next() { - err := rows.Scan(&cid, &columnName, &columnType, ¬null, &dfltValue, &pk) - if err != nil { - return fmt.Errorf("failed to scan table info row: %w", err) - } - if expectedType, exists := expectedColumns[columnName]; exists { - if columnType != expectedType { - return fmt.Errorf("column %s has incorrect type: expected %s, got %s", columnName, expectedType, columnType) - } - columnsFound[columnName] = true - } - } - - for col := range expectedColumns { - if 
!columnsFound[col] { - return fmt.Errorf("missing expected column in %s table: %s", tableName, col) - } - } - - return nil + return } -func (d *CustomDataset) getRecordsCount() (int, error) { - var count int - err := d.db.QueryRow("SELECT COUNT(" + promptHashCol + ") FROM " + tableName + ";").Scan(&count) - if err != nil { - return 0, fmt.Errorf("failed to query database: %w", err) - } - return count, nil +// getRandomResponse returns a randomly selected element from the given array, array is not empty +func (d *CustomDataset) getRandomResponse(responses [][]string) []string { + return responses[d.random.RandomInt(0, len(responses)-1)] } -func (d *CustomDataset) loadDatabaseInMemory(path string) error { - d.logger.V(logging.INFO).Info("Loading database into memory...") - start := time.Now() - - // Create in-memory database - var err error - d.db, err = sql.Open("sqlite3", ":memory:") - if err != nil { - return fmt.Errorf("failed to create in-memory database: %w", err) - } - - // Use ATTACH to copy the database - attachSQL := fmt.Sprintf("ATTACH DATABASE '%s' AS source", path) - _, err = d.db.Exec(attachSQL) - if err != nil { - if closeErr := d.db.Close(); closeErr != nil { - d.logger.Error(closeErr, "failed to close in-memory database after attach failure") - } - d.db = nil - return fmt.Errorf("failed to attach source database: %w", err) - } - - // Copy the table structure first - _, err = d.db.Exec(`CREATE TABLE llmd ( - id INTEGER PRIMARY KEY, - prompt_hash BLOB, - gen_tokens JSON, - n_gen_tokens INTEGER - )`) - if err != nil { - if closeErr := d.db.Close(); closeErr != nil { - d.logger.Error(closeErr, "failed to close in-memory database after create table failure") - } - d.db = nil - return fmt.Errorf("failed to create table: %w", err) - } - - // Copy the data - _, err = d.db.Exec("INSERT INTO llmd SELECT * FROM source.llmd") - if err != nil { - if closeErr := d.db.Close(); closeErr != nil { - d.logger.Error(closeErr, "failed to close in-memory database after copy failure") - } - d.db = nil - return fmt.Errorf("failed to copy data: %w", err) - } - - // Detach the source database - _, err = d.db.Exec("DETACH DATABASE source") - if err != nil { - d.logger.Error(err, "failed to detach source database") - } - - loadTime := time.Since(start) - d.logger.V(logging.INFO).Info("Database loaded into memory", "load_time", loadTime.String()) - return nil -} - -func (d *CustomDataset) connectToDB(path string, useInMemory bool) error { - if d.db != nil { - err := d.db.Close() - if err != nil { - d.logger.Error(err, "failed to close existing database connection") - } - d.db = nil - } - // check if file exists - _, err := os.Stat(path) - if err != nil { - return fmt.Errorf("database file does not exist: %w", err) - } - - if useInMemory { - err = d.loadDatabaseInMemory(path) - if err != nil { - return err +// GetTokens returns tokens and finishReason for the given request and mode (echo or random) +// In echo mode the prompt is returned. 
+// In random mode the following steps are performed: +// Calculate the maximum response length (based on max-tokens or max-completion-tokens or max-model-len). +// If the dataset contains responses for the given prompt and there are responses with length <= +// the maximum response length - use a random one from the list; +// otherwise select a random one from the longer responses and trim it as required. +// If no responses were found in the dataset for the given prompt, +// get a random record from the dataset with response length equal to or lower than the maximum response length; +// if there are no records shorter than or equal to the maximum length - get a random response from the dataset +// and trim it to the required length. +// If ignore_eos=true the response will always have exactly the maximum number of tokens; missing tokens +// are randomly selected from the hard-coded collection. +func (d *CustomDataset) GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error) { + if mode == common.ModeEcho { + return d.getTokensInEchoMode(req) + } + + maxResponseLen, _ := d.calculateResponseMaxLen(req) + responseTokens := []string{} + + // get all records for the hashed prompt + promptHash := d.getPromptHash(req) + promptHashHex := d.getPromptHashHex(promptHash) + responses, err := d.sqliteHelper.getResponsesForPrompt(promptHashHex) + if err != nil { + return responseTokens, "", err + } + + if len(responses) > 0 { + // has responses for the given request + d.logger.V(logging.TRACE).Info("Responses were found in the dataset for the request's prompt") + shorterOrEqLenResponses, equalLenResponses, longerLenResponses := d.categorizeResponses(responses, maxResponseLen) + + if req.GetIgnoreEOS() { + // must return a response with exactly the calculated max length + switch { + case len(equalLenResponses) > 0: + // has responses with the required length - return a randomly selected response + responseTokens = d.getRandomResponse(equalLenResponses) + case len(longerLenResponses) > 0: + // has responses longer than required - return a randomly selected trimmed response + responseTokens = d.getRandomResponse(longerLenResponses)[:maxResponseLen] + default: + // all responses are shorter than required, select randomly and pad with random tokens + responseTokens = d.getRandomResponse(shorterOrEqLenResponses) + responseTokens = append(responseTokens, d.generatePresetRandomTokens(maxResponseLen-len(responseTokens))...) + } + } else { + // has responses for the request, return a response shorter than or equal to maxResponseLen + if len(shorterOrEqLenResponses) > 0 { + // has responses shorter than or equal to the required length - return a randomly selected response + responseTokens = d.getRandomResponse(shorterOrEqLenResponses) + } else { + // all responses are longer than required, use a randomly selected trimmed response + responseTokens = d.getRandomResponse(longerLenResponses)[:maxResponseLen] + } + } + } else { + // no responses for the given request + d.logger.V(logging.TRACE).Info("No responses in the dataset for the request's prompt") + // try to find a random response with number of tokens <= the tokens limit + randomResponses, err := d.sqliteHelper.getResponsesForLen(maxResponseLen) + if err != nil { + return responseTokens, "", err + } + if len(randomResponses) == 0 { + // failed to get a response with number of tokens <= the tokens limit, get a response with any number of tokens + randomResponses, err = d.sqliteHelper.getRandomResponse() + if err != nil { + return responseTokens, "", err + } + if len(randomResponses) == 0 { + // shouldn't happen + return responseTokens, "", errors.New("dataset is empty") + } + } + responseTokens = randomResponses[0] + // if the response has too many tokens, trim it + if 
len(randomResponses[0]) > maxResponseLen { + responseTokens = randomResponses[0][:maxResponseLen] } - d.db = nil - return fmt.Errorf("failed to release exclusive lock: %w", err) } - if d.db != nil { - return d.db.Close() - } - return nil -} - -func unmarshalAllRecords(rows *sql.Rows) ([][]string, error) { - var tokensList [][]string - for rows.Next() { - var tokensJSON string - if err := rows.Scan(&tokensJSON); err != nil { - return nil, fmt.Errorf("failed to scan row: %w", err) - } - - var tokens []string - if err := json.Unmarshal([]byte(tokensJSON), &tokens); err != nil { - return nil, fmt.Errorf("failed to unmarshal tokens JSON: %w", err) - } - tokensList = append(tokensList, tokens) + finishReason := common.StopFinishReason + if len(responseTokens) == maxResponseLen { + finishReason = common.LengthFinishReason } - return tokensList, nil -} -func (d *CustomDataset) GetPromptHash(req openaiserverapi.CompletionRequest) []byte { - hashArray := sha256.Sum256([]byte(req.GetFullPrompt())) - return hashArray[:] -} - -func (d *CustomDataset) GetPromptHashHex(hashBytes []byte) string { - return hex.EncodeToString(hashBytes) -} - -// GetTokens returns tokens and finishReason for the given request and mode (echo or random) -func (d *CustomDataset) GetTokens(req openaiserverapi.CompletionRequest, mode string, random *common.Random) ([]string, string, error) { - if mode == common.ModeEcho { - return d.echo(req) - } - nTokensToGen, finishReason := howManyTokensToGen(req.ExtractMaxTokens(), req.GetIgnoreEOS(), random) - tokens, err := d.GenerateTokens(req, nTokensToGen, finishReason, random) - return tokens, finishReason, err -} - -func (d *CustomDataset) query(query string, nTokens int, random *common.Random) ([][]string, error) { - rows, err := d.db.Query(query) - if err != nil { - if !d.hasWarned { - d.logger.Error(err, "failed to query database. Ensure dataset file is still valid. 
Will generate random tokens instead.") - d.hasWarned = true - } - return [][]string{GenPresetRandomTokens(random, nTokens)}, nil - } - defer func() { - if cerr := rows.Close(); cerr != nil { - d.logger.Error(cerr, "failed to close rows after query") - } - }() - return unmarshalAllRecords(rows) -} - -func (d *CustomDataset) GenerateTokens(req openaiserverapi.CompletionRequest, nTokens int, finishReason string, - random *common.Random) ([]string, error) { - // query by prompt hash first - promptHash := d.GetPromptHash(req) - promptHashHex := d.GetPromptHashHex(promptHash) - query := "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + promptHashCol + "=X'" + promptHashHex + "';" - tokensList, err := d.query(query, nTokens, random) - - // filter out results according to finish reason - var filteredTokensList [][]string - if finishReason != LengthFinishReason && finishReason != StopFinishReason { - d.logger.Error(errors.New("unknown finish reason"), "unexpected finish reason", "reason", finishReason) - } - for _, tokens := range tokensList { - if finishReason == StopFinishReason && len(tokens) <= nTokens { - filteredTokensList = append(filteredTokensList, tokens) - } else if finishReason == LengthFinishReason && len(tokens) == nTokens { - filteredTokensList = append(filteredTokensList, tokens) - } - } - tokensList = filteredTokensList - - if err != nil || len(filteredTokensList) == 0 { - switch finishReason { - case LengthFinishReason: - query = "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + nGenTokensCol + "=" + strconv.Itoa(nTokens) + ";" - tokensList, err = d.query(query, nTokens, random) - case StopFinishReason: - query = "SELECT " + genTokensCol + " FROM " + tableName + " WHERE " + nGenTokensCol + "<=" + strconv.Itoa(nTokens) + ";" - tokensList, err = d.query(query, nTokens, random) - } - } - - if err != nil || len(tokensList) == 0 { - // if both queries fail or return no results, generate random tokens - return GenPresetRandomTokens(random, nTokens), nil - } - if d.hasWarned { - d.hasWarned = false - } - randIndex := random.RandomInt(0, len(tokensList)-1) - return tokensList[randIndex], nil + return responseTokens, finishReason, nil } diff --git a/pkg/dataset/custom_dataset_downloader.go b/pkg/dataset/custom_dataset_downloader.go new file mode 100644 index 0000000..4a1d766 --- /dev/null +++ b/pkg/dataset/custom_dataset_downloader.go @@ -0,0 +1,181 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package dataset + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "time" + + "github.com/go-logr/logr" + "github.com/llm-d/llm-d-inference-sim/pkg/common/logging" +) + +type CustomDatasetDownloader struct { + logger logr.Logger +} + +const ( + progressLogTimeInterval = 5 * time.Second + progressLogPercentInterval = 10 +) + +// progressReader wraps an io.Reader and logs download progress. 
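+// Progress is logged at most every progressLogTimeInterval seconds or every progressLogPercentInterval percent of progress, whichever comes first.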
+type progressReader struct { + io.Reader + total int64 + downloaded int64 + startTime time.Time + lastPct int + lastLogTime time.Time + logger logr.Logger + ctx context.Context +} + +func NewDsDownloader(logger logr.Logger) *CustomDatasetDownloader { + return &CustomDatasetDownloader{logger: logger} +} + +// DownloadDataset downloads the dataset from the given url and stores it at the given path. +// The error result is named so the deferred close handlers can join their errors into it. +func (d *CustomDatasetDownloader) DownloadDataset(ctx context.Context, url string, path string) (err error) { + folder := filepath.Dir(path) + err = os.MkdirAll(folder, 0755) + if err != nil { + return fmt.Errorf("failed to create parent directory: %w", err) + } + + if _, err := os.Stat(path); err == nil { + // file already exists + d.logger.V(logging.INFO).Info("Dataset file already exists, should not download: " + path) + return nil + } + + out, err := os.Create(path) + if err != nil { + return err + } + defer func() { + cerr := out.Close() + if cerr != nil { + d.logger.Error(cerr, "failed to close file after download") + err = errors.Join(err, cerr) + } + }() + + d.logger.V(logging.INFO).Info("Using dataset-url", "dataset-url", url) + resp, err := http.Get(url) + if err != nil { + return err + } + defer func() { + cerr := resp.Body.Close() + if cerr != nil { + d.logger.Error(cerr, "failed to close response body after download") + err = errors.Join(err, cerr) + } + }() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("dataset download bad status: %s", resp.Status) + } + + // Progress reader with context + pr := &progressReader{ + Reader: resp.Body, + total: resp.ContentLength, + logger: d.logger, + ctx: ctx, + startTime: time.Now(), + } + + written, err := io.Copy(out, pr) + if err != nil { + // Remove incomplete file + cerr := os.Remove(path) + if cerr != nil { + d.logger.Error(cerr, "failed to remove incomplete file after download") + } + // If context was cancelled, return a specific error + if errors.Is(err, context.Canceled) { + return errors.New("download cancelled by user") + } + return fmt.Errorf("failed to download file: %w", err) + } + // Check if file size is zero + if written == 0 { + cerr := os.Remove(path) + if cerr != nil { + d.logger.Error(cerr, "failed to remove empty file after download") + } + return errors.New("downloaded file is empty") + } + + // Ensure file is fully flushed and closed before returning success + if err := out.Sync(); err != nil { + cerr := os.Remove(path) + if cerr != nil { + d.logger.Error(cerr, "failed to remove incomplete file after download") + } + return fmt.Errorf("failed to sync file: %w", err) + } + + d.logger.V(logging.INFO).Info("Downloaded dataset", "url", url, "path", path) + return nil +} + +func (pr *progressReader) Read(p []byte) (int, error) { + select { + case <-pr.ctx.Done(): + return 0, pr.ctx.Err() + default: + } + n, err := pr.Reader.Read(p) + pr.downloaded += int64(n) + if pr.total > 0 { + pct := int(float64(pr.downloaded) * 100 / float64(pr.total)) + now := time.Now() + + timeSinceLastLog := now.Sub(pr.lastLogTime).Seconds() + pctDiff := pct - pr.lastPct + + if timeSinceLastLog >= progressLogTimeInterval.Seconds() || (pctDiff >= progressLogPercentInterval && pct != pr.lastPct) { + // progress will be shown every interval seconds or every interval percent of progress + pr.logProgress(pct) + pr.lastPct = pct + pr.lastLogTime = now + } + } + return n, err +} + +func (pr *progressReader) logProgress(pct int) { + elapsedTime := time.Since(pr.startTime).Seconds() + speed := float64(pr.downloaded) / (1024 * 1024 * elapsedTime) 
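+ // estimate the remaining time from the average download speed so far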
+ remainingTime := float64(pr.total-pr.downloaded) / (float64(pr.downloaded) / elapsedTime) + if pct != 100 { + pr.logger.V(logging.INFO).Info("Dataset download progress", "%", pct, "speed (MB/s)", speed, "remaining time (s)", remainingTime) + fmt.Println("Dataset download progress", "%", pct, "speed (MB/s)", speed, "remaining time (s)", remainingTime) + } else { + pr.logger.V(logging.INFO).Info("Download completed", "average speed (MB/s)", speed, "total time (s)", elapsedTime) + fmt.Println("Download completed", "average speed (MB/s)", speed, "total time (s)", elapsedTime) + } +} diff --git a/pkg/dataset/custom_dataset_test.go b/pkg/dataset/custom_dataset_test.go index 907a076..ef236ff 100644 --- a/pkg/dataset/custom_dataset_test.go +++ b/pkg/dataset/custom_dataset_test.go @@ -22,11 +22,13 @@ import ( "os" "time" + "k8s.io/klog/v2" + "github.com/llm-d/llm-d-inference-sim/pkg/common" openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" - "k8s.io/klog/v2" _ "github.com/mattn/go-sqlite3" ) @@ -37,7 +39,8 @@ const ( var _ = Describe("CustomDataset", Ordered, func() { var ( - dataset *CustomDataset + sqliteHelper *sqliteHelper + dsDownloader *CustomDatasetDownloader file_folder string path string validDBPath string @@ -54,7 +57,8 @@ var _ = Describe("CustomDataset", Ordered, func() { }) BeforeEach(func() { - dataset = &CustomDataset{} + sqliteHelper = newSqliteHelper(klog.Background()) + dsDownloader = NewDsDownloader(klog.Background()) file_folder = ".llm-d" path = file_folder + "/test.sqlite3" err := os.MkdirAll(file_folder, os.ModePerm) @@ -67,15 +71,8 @@ var _ = Describe("CustomDataset", Ordered, func() { pathToInvalidTypeDB = file_folder + "/test.invalid.type.sqlite3" }) - AfterEach(func() { - if dataset.db != nil { - err := dataset.db.Close() - Expect(err).NotTo(HaveOccurred()) - } - }) - It("should return error for invalid DB path", func() { - err := dataset.connectToDB("/invalid/path/to/db.sqlite", false) + err := sqliteHelper.connectToDB("/invalid/path/to/db.sqlite", false) Expect(err).To(HaveOccurred()) }) @@ -86,9 +83,8 @@ var _ = Describe("CustomDataset", Ordered, func() { err = os.Remove(path) Expect(err).NotTo(HaveOccurred()) } - url := "https://llm-d.ai" - err = dataset.downloadDataset(context.Background(), url, path) + err = dsDownloader.DownloadDataset(context.Background(), url, path) Expect(err).NotTo(HaveOccurred()) _, err = os.Stat(path) Expect(err).NotTo(HaveOccurred()) @@ -98,22 +94,23 @@ var _ = Describe("CustomDataset", Ordered, func() { It("should not download file from url", func() { url := "https://256.256.256.256" // invalid url - err := dataset.downloadDataset(context.Background(), url, path) + err := dsDownloader.DownloadDataset(context.Background(), url, path) Expect(err).To(HaveOccurred()) }) It("should successfully init dataset", func() { - err := dataset.Init(context.Background(), klog.Background(), validDBPath, "", false) + dataset := &CustomDataset{} + err := dataset.Init(context.Background(), klog.Background(), random, validDBPath, false, 1024) Expect(err).NotTo(HaveOccurred()) - row := dataset.db.QueryRow("SELECT n_gen_tokens FROM llmd WHERE prompt_hash=X'74bf14c09c038321cba39717dae1dc732823ae4abd8e155959367629a3c109a8';") + row := dataset.sqliteHelper.db.QueryRow("SELECT n_gen_tokens FROM llmd WHERE prompt_hash=X'74bf14c09c038321cba39717dae1dc732823ae4abd8e155959367629a3c109a8';") var n_gen_tokens int err = row.Scan(&n_gen_tokens) Expect(err).NotTo(HaveOccurred()) 
Expect(n_gen_tokens).To(Equal(4)) var jsonStr string - row = dataset.db.QueryRow("SELECT gen_tokens FROM llmd WHERE prompt_hash=X'74bf14c09c038321cba39717dae1dc732823ae4abd8e155959367629a3c109a8';") + row = dataset.sqliteHelper.db.QueryRow("SELECT gen_tokens FROM llmd WHERE prompt_hash=X'74bf14c09c038321cba39717dae1dc732823ae4abd8e155959367629a3c109a8';") err = row.Scan(&jsonStr) Expect(err).NotTo(HaveOccurred()) var tokens []string @@ -121,33 +118,35 @@ var _ = Describe("CustomDataset", Ordered, func() { Expect(err).NotTo(HaveOccurred()) Expect(tokens).To(Equal([]string{"Hello", " llm-d ", "world", "!"})) + err = dataset.sqliteHelper.db.Close() + Expect(err).NotTo(HaveOccurred()) }) It("should return error for non-existing DB path", func() { - err := dataset.connectToDB(pathNotExist, false) + err := sqliteHelper.connectToDB(pathNotExist, false) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("database file does not exist")) }) It("should return error for invalid DB file", func() { - err := dataset.connectToDB(pathToInvalidDB, false) + err := sqliteHelper.connectToDB(pathToInvalidDB, false) Expect(err).To(HaveOccurred()) }) It("should return error for DB with invalid table", func() { - err := dataset.connectToDB(pathToInvalidTableDB, false) + err := sqliteHelper.connectToDB(pathToInvalidTableDB, false) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("failed to verify database")) }) It("should return error for DB with invalid column", func() { - err := dataset.connectToDB(pathToInvalidColumnDB, false) + err := sqliteHelper.connectToDB(pathToInvalidColumnDB, false) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("missing expected column")) }) It("should return error for DB with invalid column type", func() { - err := dataset.connectToDB(pathToInvalidTypeDB, false) + err := sqliteHelper.connectToDB(pathToInvalidTypeDB, false) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("incorrect type")) }) @@ -159,8 +158,8 @@ var _ = Describe("CustomDataset", Ordered, func() { req := &openaiserverapi.TextCompletionRequest{ Prompt: testPrompt, } - - hashBytes := dataset.GetPromptHash(req) + dataset := &CustomDataset{} + hashBytes := dataset.getPromptHash(req) Expect(hashBytes).To(Equal(expectedHashBytes)) }) @@ -170,52 +169,98 @@ var _ = Describe("CustomDataset", Ordered, func() { req := &openaiserverapi.TextCompletionRequest{ Prompt: testPrompt, } - - hashBytes := dataset.GetPromptHash(req) - hashHex := dataset.GetPromptHashHex(hashBytes) + dataset := &CustomDataset{} + hashBytes := dataset.getPromptHash(req) + hashHex := dataset.getPromptHashHex(hashBytes) Expect(hashHex).To(Equal(expectedHashHex)) }) It("should return tokens for existing prompt", func() { - err := dataset.Init(context.Background(), klog.Background(), validDBPath, "", false) + dataset := &CustomDataset{} + err := dataset.Init(context.Background(), klog.Background(), random, validDBPath, false, 1024) Expect(err).NotTo(HaveOccurred()) req := &openaiserverapi.TextCompletionRequest{ Prompt: testPrompt, } - tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom, random) + tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom) Expect(err).NotTo(HaveOccurred()) - Expect(finishReason).To(Equal(StopFinishReason)) + Expect(finishReason).To(Equal(common.StopFinishReason)) if len(tokens) >= 4 { // The number of tokens to generate is random, and if it's less than 4 // we will not get these tokens 
Expect(tokens).To(Equal([]string{"Hello", " llm-d ", "world", "!"})) } + err = dataset.sqliteHelper.db.Close() + Expect(err).NotTo(HaveOccurred()) }) It("should return at most 2 tokens for existing prompt", func() { - err := dataset.Init(context.Background(), klog.Background(), validDBPath, "", false) + dataset := &CustomDataset{} + err := dataset.Init(context.Background(), klog.Background(), random, validDBPath, false, 1024) Expect(err).NotTo(HaveOccurred()) n := int64(2) req := &openaiserverapi.TextCompletionRequest{ Prompt: testPrompt, MaxTokens: &n, } - tokens, _, err := dataset.GetTokens(req, common.ModeRandom, random) + tokens, _, err := dataset.GetTokens(req, common.ModeRandom) Expect(err).NotTo(HaveOccurred()) Expect(len(tokens)).To(BeNumerically("<=", 2)) + err = dataset.sqliteHelper.db.Close() + Expect(err).NotTo(HaveOccurred()) }) It("should successfully init dataset with in-memory option", func() { - err := dataset.Init(context.Background(), klog.Background(), validDBPath, "", true) + dataset := &CustomDataset{} + err := dataset.Init(context.Background(), klog.Background(), random, validDBPath, true, 1024) Expect(err).NotTo(HaveOccurred()) req := &openaiserverapi.TextCompletionRequest{ Prompt: testPrompt, } - tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom, random) + tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom) Expect(err).NotTo(HaveOccurred()) - Expect(finishReason).To(Equal(StopFinishReason)) + Expect(finishReason).To(Equal(common.StopFinishReason)) Expect(tokens).To(Equal([]string{"Hello", " llm-d ", "world", "!"})) + err = dataset.sqliteHelper.db.Close() + Expect(err).NotTo(HaveOccurred()) + }) +}) + +var _ = Describe("custom dataset for multiple simulators", Ordered, func() { + It("should not fail on custom datasets initialization", func() { + file_folder := ".llm-d" + validDBPath := file_folder + "/test.valid.sqlite3" + + random1 := common.NewRandom(time.Now().UnixNano(), 8081) + dataset1 := &CustomDataset{} + err := dataset1.Init(context.Background(), klog.Background(), random1, validDBPath, false, 1024) + Expect(err).NotTo(HaveOccurred()) + + random2 := common.NewRandom(time.Now().UnixNano(), 8082) + dataset2 := &CustomDataset{} + err = dataset2.Init(context.Background(), klog.Background(), random2, validDBPath, false, 1024) + Expect(err).NotTo(HaveOccurred()) + }) +}) + +var _ = Describe("download custom dataset from HF", Ordered, func() { + // currently there is only one dataset, which is too large; + // once we create a small sample dataset - restore this test + XIt("should download and save ds", func() { + url := "https://huggingface.co/datasets/hf07397/inference-sim-datasets/resolve/91ffa7aafdfd6b3b1af228a517edc1e8f22cd274/huggingface/ShareGPT_Vicuna_unfiltered/conversations.sqlite3" + downloader := NewDsDownloader(klog.Background()) + tempFile := "./ds1.sqlite3" + + if _, err := os.Stat(tempFile); err == nil { + err := os.Remove(tempFile) + Expect(err).NotTo(HaveOccurred()) + } + err := downloader.DownloadDataset(context.Background(), url, tempFile) + Expect(err).NotTo(HaveOccurred()) + + err = os.Remove(tempFile) + Expect(err).NotTo(HaveOccurred()) }) }) diff --git a/pkg/dataset/dataset.go b/pkg/dataset/dataset.go index 15c737a..ec7bd6f 100644 --- a/pkg/dataset/dataset.go +++ b/pkg/dataset/dataset.go @@ -26,34 +26,8 @@ import ( _ "github.com/mattn/go-sqlite3" ) -const ( - RoleAssistant = "assistant" - RoleUser = "user" -) - -const ( - ResponseLenMax = 128 - responseLenMean = 40 - responseLenStddev = 20 - 
stopFinishReasonProbability = 0.8 - - StopFinishReason = "stop" - LengthFinishReason = "length" - ToolsFinishReason = "tool_calls" - RemoteDecodeFinishReason = "remote_decode" -) - -// this array defines the probabilities for the buckets to be used for the generation of number of tokens in response -var respLenBucketsProbabilities = [...]float64{0.2, 0.3, 0.2, 0.05, 0.1, 0.15} -var cumulativeBucketsProbabilities []float64 - -const ( - flexBucketIndex = 3 - maxFixedBucketSize = 20 -) - // list of responses to use in random mode for completion requests -var chatCompletionFakeResponses = []string{ +var completionFakeResponses = []string{ `Testing@, #testing 1$ ,2%,3^, [4&*5], 6~, 7-_ + (8 : 9) / \ < > .`, `Testing, testing 1,2,3.`, `I am fine, how are you today?`, @@ -68,55 +42,113 @@ var chatCompletionFakeResponses = []string{ } type Dataset interface { - // Init initializes the dataset using configs - Init(ctx context.Context, logger logr.Logger, path string, url string, useInMemory bool) error // Close closes the dataset Close() error // GetTokens returns tokens for the given request and mode (echo or random) - GetTokens(req openaiserverapi.CompletionRequest, mode string, random *common.Random) ([]string, string, error) + GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error) +} + +type BaseDataset struct { + logger logr.Logger + maxModelLen int + random *common.Random + histogramHelper *histogramHelper +} + +func (d *BaseDataset) Init(ctx context.Context, logger logr.Logger, random *common.Random, maxModelLen int) error { + d.logger = logger + d.maxModelLen = maxModelLen + d.random = random + d.histogramHelper = newHistogramHelper(d.random) + + return nil +} + +func (d *BaseDataset) Close() error { + return nil +} + +// GetTokens returns tokens and finishReason for the given request and mode (echo or random) +func (d *BaseDataset) GetTokens(req openaiserverapi.CompletionRequest, mode string) ([]string, string, error) { + if mode == common.ModeEcho { + return d.getTokensInEchoMode(req) + } + + maxRespTokens, isMaxTokensInReq := d.calculateResponseMaxLen(req) + + numOfRespTokens := 0 + finishReason := common.StopFinishReason + + switch { + case req.GetIgnoreEOS(): + // ignore_eos is true - the response must have the maximum number of tokens + numOfRespTokens = maxRespTokens + case isMaxTokensInReq: + // max tokens is defined in the request - generate the number of tokens in the response based on the histogram + numOfRespTokens = d.histogramHelper.getResponseLengthByHistogram(maxRespTokens) + default: + // no token limit in the request - use a Gaussian with hard-coded mean and standard deviation + numOfRespTokens = d.getRandomResponseLenByGaussian(maxRespTokens) + } + + if numOfRespTokens == maxRespTokens { + // if the response is created with the maximum number of tokens - the finish reason will be 'length' + finishReason = common.LengthFinishReason + } + + return d.generatePresetRandomTokens(numOfRespTokens), finishReason, nil } -func init() { - cumulativeBucketsProbabilities = make([]float64, len(respLenBucketsProbabilities)) - sum := 0.0 +func (d *BaseDataset) getTokensInEchoMode(req openaiserverapi.CompletionRequest) ([]string, string, error) { + tokens := common.Tokenize(req.ExtractPrompt()) + maxTokens := req.GetMaxCompletionTokens() + finishReason := common.StopFinishReason - for i, val := range respLenBucketsProbabilities { - sum += val - cumulativeBucketsProbabilities[i] = sum + if maxTokens != nil && len(tokens) >= int(*maxTokens) { + finishReason = common.LengthFinishReason } + + return tokens, finishReason, nil } -// GetRandomResponseLen returns int in range [1, responseLenMax] +// calculateResponseMaxLen - calculates the maximum length of a response to be randomly chosen from the dataset +// for the given request and the simulator configuration. +// If max-tokens/max-completion-tokens is defined - use it, +// otherwise use max-model-len minus the number of tokens in the prompt; +// the boolean return value reports whether a max tokens number was passed in the request +func (d *BaseDataset) calculateResponseMaxLen(req openaiserverapi.CompletionRequest) (int, bool) { + maxTokens := req.GetMaxCompletionTokens() + + if maxTokens != nil { + return int(*maxTokens), true + } + + return d.maxModelLen - len(common.Tokenize(req.GetFullPrompt())), false +} + +// getRandomResponseLenByGaussian returns an int in the range [1, maxLen] // numbers are chosen according a gaussian distribution with mean responseLenMean, and standard deviation responseLenStddev -func GetRandomResponseLen(random *common.Random) int { +func (d *BaseDataset) getRandomResponseLenByGaussian(maxLen int) int { for { - val := random.RandomNorm(responseLenMean, responseLenStddev) - if val >= 1 && val <= ResponseLenMax { + val := d.random.RandomNorm(responseLenMean, responseLenStddev) + if val >= 1 && val <= float64(maxLen) { return int(math.Round(val)) } // else reject and resample } } -// GetRandomFinishReason returns finish reason with the probability for 'stop' as defined by stopFinishReasonProbability -func GetRandomFinishReason(random *common.Random) string { - if random.RandomFloat(0, 1) < stopFinishReasonProbability { - return StopFinishReason - } - return LengthFinishReason -} - -// GenPresetRandomTokens generates random tokens for the required number of tokens, -// select randomly a sentence from chatCompletionFakeResponses, +// generatePresetRandomTokens generates random tokens for the required number of tokens, +// select randomly a sentence from completionFakeResponses, // if number of tokens is lower than required - select another sentence, -// continue until the required number of tokens is achieved -func GenPresetRandomTokens(random *common.Random, numOfTokens int) []string { +// continue until the required number of tokens is achieved; +// returns exactly numOfTokens tokens +func (d BaseDataset) generatePresetRandomTokens(numOfTokens int) []string { allTokens := make([]string, 0) for len(allTokens) < numOfTokens { - index := random.RandomInt(0, len(chatCompletionFakeResponses)-1) - // create tokens from text, splitting by spaces and special characters - tokens := common.Tokenize(chatCompletionFakeResponses[index]) + index := d.random.RandomInt(0, len(completionFakeResponses)-1) + tokens := common.Tokenize(completionFakeResponses[index]) remaining := numOfTokens - len(allTokens) if len(tokens) > remaining { @@ -134,171 +166,3 @@ return allTokens } - -// howManyTokensToGen generates the number of tokens to be returned in a response, and the finish reason (see constants) -// if maxCompletionTokens is defined -// - currently, the generated number of words in the text will be equal to it value -// - in future - need to find statistics about generated tokens distribution and return less tokens in part os requests -// - finish reason will be chosen randomly from the collection (stop, length) with 80% for stop and 20% for length -// if maxCompletionTokens is nil -// - the response text's length is randomly chosen from the range [1, responseLenMax] according additional parameters -// - finish reason is stop -// if 
ignore_eos is true - the response will be generated with exactly maxCompletionTokens tokens -// - request was validated so that when ignore_eos is true, maxCompletionTokens must be defined -func howManyTokensToGen(maxCompletionTokens *int64, ignore_eos bool, random *common.Random) (int, string) { - numOfTokens := 0 - finishReason := StopFinishReason - - // no max completion tokens, return text with random length - if maxCompletionTokens == nil { - numOfTokens = GetRandomResponseLen(random) - } else { - maxTokens := int(*maxCompletionTokens) - if ignore_eos { - numOfTokens = maxTokens - finishReason = LengthFinishReason - } else { - // max tokens is defined - generate real length of the response based on it - numOfTokens = getResponseLengthByHistogram(random, maxTokens) - if numOfTokens == maxTokens { - // if response should be create with maximum number of tokens - finish reason will be 'length' - finishReason = LengthFinishReason - } - } - } - - return numOfTokens, finishReason -} - -// getResponseLengthByHistogram calculates the number of tokens to be returned in a response based on the max tokens value and the pre-defined buckets. -// The response length is distributed according to the probabilities, defined in respLenBucketsProbabilities. -// The histogram contains equally sized buckets and the last special bucket, which contains only the maxTokens value. -// The last element of respLenBucketsProbabilities defines the probability of a reposnse with maxToken tokens. -// Other values define probabilities for the equally sized buckets. -// If maxToken is small (smaller than number of buckets) - the response length is randomly selected from the range [1, maxTokens] -func getResponseLengthByHistogram(random *common.Random, maxTokens int) int { - if maxTokens <= 1 { - return maxTokens - } - // maxTokens is small - no need to use the histogram of probabilities, just select a random value in the range [1, maxTokens] - if maxTokens <= len(cumulativeBucketsProbabilities) { - res := random.RandomInt(1, maxTokens) - return res - } - - r := random.RandomFloat(0, 1) - - // check if r is in the last bucket, then maxTokens should be returned - if r > cumulativeBucketsProbabilities[len(cumulativeBucketsProbabilities)-2] { - return maxTokens - } - - // determine which bucket to use, the bucket with a cumulative probability larger than r is the bucket to use - // initialize bucketIndex with the last bucket to handle the case (which should not happen) when the probabilities sum is less than 1 - bucketIndex := len(cumulativeBucketsProbabilities) - 1 - for i, c := range cumulativeBucketsProbabilities { - if r <= c { - bucketIndex = i - break - } - } - - // calculate the size of all of the buckets (except the special last bucket) - start, end := calcBucketBoundaries(maxTokens, bucketIndex) - - // pick uniformly within the bucket’s range - return random.RandomInt(start, end) -} - -// calcBucketBoundaries calculates boundaries of a bucket with the given index. -// Maximum size for equally sized buckets is defined by maxFixedBucketSize. -// [maxFixedBucketSize*(number-of-buckets-1)+1] is the value of maxTokens for which -// division to equally size buckets will give buckets with size maxFixedBucketSize. -// If maxTokens is [maxFixedBucketSize*(number-of-buckets-1)+1] or less, -// all buckets will be of equal size, except the last bucket, which contains only one value. 
-// If maxTokens is higher than [maxFixedBucketSize*(number-of-buckets-1)+1], -// and flexBucketIndex is valid (between 0 and number of buckets - 1) the buckets sizes will not be equal. -// In this case, all buckets except the one at flexBucketIndex index will have size 20 (and the last is with size 1), -// and the bucket at flexBucketIndex index will 'stretch' to cover the remaining range. -func calcBucketBoundaries(maxTokens int, bucketIndex int) (start int, end int) { - maxEquallyBucketsSz := maxFixedBucketSize*(len(cumulativeBucketsProbabilities)-1) + 1 - - if maxTokens <= maxEquallyBucketsSz || flexBucketIndex < 0 || flexBucketIndex >= len(cumulativeBucketsProbabilities)-1 { - // create equally size buckets - // calculate the size of all of the buckets (except the special last bucket) - bucketSize := float64(maxTokens-1) / float64(len(cumulativeBucketsProbabilities)-1) - start = int(bucketSize*float64(bucketIndex)) + 1 - end = int(bucketSize * float64(bucketIndex+1)) - } else { - // create non-equally sized buckets and find boundaries of the required bucket - if bucketIndex < flexBucketIndex { - // the relevant bucket is before the flex bucket, all buckets are of the same size (maxFixedBucketSize) - // start is the minimum number in the required bucket - start = maxFixedBucketSize*bucketIndex + 1 - end = maxFixedBucketSize * (bucketIndex + 1) - } else { - flexBucketSize := maxTokens - (maxFixedBucketSize * (len(cumulativeBucketsProbabilities) - 2)) - - if bucketIndex == flexBucketIndex { - // the relevant bucket is the flex bucket - start = int(maxFixedBucketSize*float64(bucketIndex)) + 1 - end = maxFixedBucketSize*bucketIndex + flexBucketSize - } else { - // the relevant bucket is one of buckets after the flex bucket - start = int(maxFixedBucketSize*float64(bucketIndex-1)) + flexBucketSize + 1 - end = maxFixedBucketSize*bucketIndex + flexBucketSize - } - } - } - - // sometimes end could be maxTokens because of rounding, change the value to maxToken-1 - if end >= maxTokens { - end = maxTokens - 1 - } - - return start, end -} - -// EchoResponseTokens returns needed tokens, from a given text -// considering max completion tokens if it is not nil, and a finish reason (stop or length) -func EchoResponseTokens(maxCompletionTokens *int64, text string) ([]string, string) { - tokens := common.Tokenize(text) - // no max completion tokens, return entire text - if maxCompletionTokens == nil { - return tokens, StopFinishReason - } - - if *maxCompletionTokens >= int64(len(tokens)) { - return tokens, StopFinishReason - } - // return truncated text - return tokens[0:*maxCompletionTokens], LengthFinishReason -} - -type BaseDataset struct { - logger logr.Logger -} - -func (d *BaseDataset) Init(ctx context.Context, logger logr.Logger, path string, url string, useInMemory bool) error { - d.logger = logger - return nil -} - -func (d *BaseDataset) Close() error { - return nil -} - -func (d *BaseDataset) echo(req openaiserverapi.CompletionRequest) ([]string, string, error) { - tokens, finishReason := EchoResponseTokens(req.ExtractMaxTokens(), req.ExtractPrompt()) - return tokens, finishReason, nil -} - -// GetTokens returns tokens and finishReason for the given request and mode (echo or random) -func (d *BaseDataset) GetTokens(req openaiserverapi.CompletionRequest, mode string, - random *common.Random) ([]string, string, error) { - if mode == common.ModeEcho { - return d.echo(req) - } - nTokensToGen, finishReason := howManyTokensToGen(req.ExtractMaxTokens(), req.GetIgnoreEOS(), random) - return 
GenPresetRandomTokens(random, nTokensToGen), finishReason, nil -} diff --git a/pkg/dataset/dataset_test.go b/pkg/dataset/dataset_test.go index 3e7974a..8efa170 100644 --- a/pkg/dataset/dataset_test.go +++ b/pkg/dataset/dataset_test.go @@ -17,39 +17,53 @@ limitations under the License. package dataset import ( + "context" "fmt" "strings" + "time" + "github.com/go-logr/logr" + "github.com/go-logr/logr/funcr" "github.com/llm-d/llm-d-inference-sim/pkg/common" openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) +func NewStdoutLogger() logr.Logger { + return funcr.New(func(prefix, args string) { + if prefix != "" { + fmt.Printf("%s: %s\n", prefix, args) + } else { + fmt.Println(args) + } + }, funcr.Options{}) +} + var _ = Describe("Dataset", Ordered, func() { var ( dataset *BaseDataset - random *common.Random ) - BeforeAll(func() { - random = common.NewRandom(time.Now().UnixNano(), 8080) - }) + createDataset := func() { + dataset = &BaseDataset{} + err := dataset.Init(context.Background(), NewStdoutLogger(), common.NewRandom(time.Now().UnixNano(), 8080), 1024) + Expect(err).ShouldNot(HaveOccurred()) + } BeforeEach(func() { - dataset = &BaseDataset{} + createDataset() }) Context("GetRandomTokens", func() { - It("should return complete text", func() { req := &openaiserverapi.ChatCompletionRequest{} - tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom, random) + tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom) Expect(err).ShouldNot(HaveOccurred()) text := strings.Join(tokens, "") Expect(IsValidText(text)).To(BeTrue()) - Expect(finishReason).Should(Equal(StopFinishReason)) + Expect(finishReason).Should(Equal(common.StopFinishReason)) }) It("should return short text", func() { @@ -57,35 +71,34 @@ var _ = Describe("Dataset", Ordered, func() { req := &openaiserverapi.ChatCompletionRequest{ MaxCompletionTokens: &maxCompletionTokens, } - tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom, random) + tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom) Expect(err).ShouldNot(HaveOccurred()) tokensCnt := int64(len(tokens)) Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens)) if tokensCnt == maxCompletionTokens { - Expect(finishReason).To(Equal(LengthFinishReason)) + Expect(finishReason).To(Equal(common.LengthFinishReason)) } else { Expect(tokensCnt).To(BeNumerically("<", maxCompletionTokens)) - Expect(finishReason).To(Equal(StopFinishReason)) + Expect(finishReason).To(Equal(common.StopFinishReason)) } }) It("should return long text", func() { - // return required number of tokens although it is higher than ResponseLenMax - maxCompletionTokens := int64(ResponseLenMax * 5) + maxCompletionTokens := int64(1000) req := &openaiserverapi.ChatCompletionRequest{ MaxTokens: &maxCompletionTokens, } - tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom, random) + tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom) Expect(err).ShouldNot(HaveOccurred()) tokensCnt := int64(len(tokens)) Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens)) text := strings.Join(tokens, "") Expect(IsValidText(text)).To(BeTrue()) if tokensCnt == maxCompletionTokens { - Expect(finishReason).To(Equal(LengthFinishReason)) + Expect(finishReason).To(Equal(common.LengthFinishReason)) } else { Expect(tokensCnt).To(BeNumerically("<", maxCompletionTokens)) - Expect(finishReason).To(Equal(StopFinishReason)) + 
Expect(finishReason).To(Equal(common.StopFinishReason))
 			}
 		})
 
@@ -96,11 +109,11 @@ var _ = Describe("Dataset", Ordered, func() {
 					MaxTokens: &n,
 				}
 				req.SetIgnoreEOS(true)
-				tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom, random)
+				tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom)
 				Expect(err).ShouldNot(HaveOccurred())
 				nGenTokens := int64(len(tokens))
 				Expect(nGenTokens).Should(Equal(n))
-				Expect(finishReason).To(Equal(LengthFinishReason))
+				Expect(finishReason).To(Equal(common.LengthFinishReason))
 			},
 			func(maxCompletionTokens int) string {
 				return fmt.Sprintf("maxCompletionTokens: %d", maxCompletionTokens)
@@ -116,22 +129,36 @@ var _ = Describe("Dataset", Ordered, func() {
 		theText := "Give a man a fish and you feed him for a day; teach a man to fish and you feed him for a lifetime"
 		theTokens := common.Tokenize(theText)
 
-		It("should return the same text since max tokens is not defined", func() {
-			tokens, finishReason := EchoResponseTokens(nil, theText)
+		It("should return the same text, max tokens is not defined", func() {
+			req := &openaiserverapi.TextCompletionRequest{
+				Prompt: theText,
+			}
+			tokens, finishReason, err := dataset.getTokensInEchoMode(req)
+			Expect(err).ShouldNot(HaveOccurred())
 			Expect(tokens).Should(Equal(theTokens))
-			Expect(finishReason).Should(Equal(StopFinishReason))
+			Expect(finishReason).Should(Equal(common.StopFinishReason))
 		})
 
-		It("should return the same text since max tokens is higher than the text length", func() {
-			maxCompletionTokens := int64(1000)
-			tokens, finishReason := EchoResponseTokens(&maxCompletionTokens, theText)
+		It("should return the same text, max tokens is higher than the text length", func() {
+			maxTokens := int64(1000)
+			req := &openaiserverapi.TextCompletionRequest{
+				Prompt:    theText,
+				MaxTokens: &maxTokens,
+			}
+			tokens, finishReason, err := dataset.getTokensInEchoMode(req)
+			Expect(err).ShouldNot(HaveOccurred())
 			Expect(tokens).Should(Equal(theTokens))
-			Expect(finishReason).Should(Equal(StopFinishReason))
+			Expect(finishReason).Should(Equal(common.StopFinishReason))
 		})
 
-		It("should return partial text", func() {
-			maxCompletionTokens := int64(2)
-			tokens, finishReason := EchoResponseTokens(&maxCompletionTokens, theText)
-			Expect(int64(len(tokens))).Should(Equal(maxCompletionTokens))
-			Expect(finishReason).Should(Equal(LengthFinishReason))
+		It("should return the same text, finish reason is length", func() {
+			maxTokens := int64(2)
+			req := &openaiserverapi.TextCompletionRequest{
+				Prompt:    theText,
+				MaxTokens: &maxTokens,
+			}
+			tokens, finishReason, err := dataset.getTokensInEchoMode(req)
+			Expect(err).ShouldNot(HaveOccurred())
+			Expect(tokens).Should(Equal(theTokens))
+			Expect(finishReason).Should(Equal(common.LengthFinishReason))
 		})
 	})
 
@@ -141,7 +168,7 @@ var _ = Describe("Dataset", Ordered, func() {
 		for _, len := range lenArr {
 			name := fmt.Sprintf("should return text with %d tokens", len)
 			It(name, func() {
-				tokens := GenPresetRandomTokens(random, len)
+				tokens := dataset.generatePresetRandomTokens(len)
 				Expect(tokens).Should(HaveLen(len))
 			})
 		}
@@ -151,15 +178,15 @@ var _ = Describe("Dataset", Ordered, func() {
 		validTxts := make([]string, 0)
 		invalidTxts := make([]string, 0)
 
-		validTxts = append(validTxts, chatCompletionFakeResponses[0][:4])
-		validTxts = append(validTxts, chatCompletionFakeResponses[1])
-		validTxts = append(validTxts, chatCompletionFakeResponses[1]+" "+chatCompletionFakeResponses[2])
+		validTxts = append(validTxts, completionFakeResponses[0][:4])
+		validTxts = append(validTxts, 
completionFakeResponses[1])
+		validTxts = append(validTxts, completionFakeResponses[1]+" "+completionFakeResponses[2])
 
-		invalidTxts = append(invalidTxts, (chatCompletionFakeResponses[1] + " " + chatCompletionFakeResponses[2])[3:4])
-		invalidTxts = append(invalidTxts, chatCompletionFakeResponses[0][4:])
-		invalidTxts = append(invalidTxts, chatCompletionFakeResponses[1]+"-"+chatCompletionFakeResponses[2])
-		invalidTxts = append(invalidTxts, chatCompletionFakeResponses[1]+" ")
-		invalidTxts = append(invalidTxts, chatCompletionFakeResponses[1]+" "+chatCompletionFakeResponses[2])
+		invalidTxts = append(invalidTxts, (completionFakeResponses[1] + " " + completionFakeResponses[2])[3:4])
+		invalidTxts = append(invalidTxts, completionFakeResponses[0][4:])
+		invalidTxts = append(invalidTxts, completionFakeResponses[1]+"-"+completionFakeResponses[2])
+		invalidTxts = append(invalidTxts, completionFakeResponses[1]+" ")
+		invalidTxts = append(invalidTxts, completionFakeResponses[1]+" "+completionFakeResponses[2])
 
 		for _, txt := range validTxts {
 			It("text should be valid", func() {
@@ -175,6 +202,9 @@ var _ = Describe("Dataset", Ordered, func() {
 	})
 
 	Context("validateBucketsBoundaries", func() {
+		// create the dataset earlier than BeforeEach, since its helper is used before It execution
+		createDataset()
+
 		type bucketBoundaries struct {
 			start int
 			end   int
 		}
@@ -189,11 +219,11 @@ var _ = Describe("Dataset", Ordered, func() {
 			{50, []bucketBoundaries{{1, 9}, {10, 19}, {20, 29}, {30, 39}, {40, 49}}}}
 
 		for _, test := range tests {
-			Expect(test.expectedBuckets).To(HaveLen(len(cumulativeBucketsProbabilities) - 1))
+			Expect(test.expectedBuckets).To(HaveLen(len(dataset.histogramHelper.cumulativeBucketsProbabilities) - 1))
 
 			It(fmt.Sprintf("should return bucket boundaries for maxTokens %d", test.maxTokens), func() {
-				for i := range len(cumulativeBucketsProbabilities) - 1 {
-					start, end := calcBucketBoundaries(test.maxTokens, i)
+				for i := range len(dataset.histogramHelper.cumulativeBucketsProbabilities) - 1 {
+					start, end := dataset.histogramHelper.calcBucketBoundaries(test.maxTokens, i)
 					Expect(start).To(Equal(test.expectedBuckets[i].start))
 					Expect(end).To(Equal(test.expectedBuckets[i].end))
 				}
diff --git a/pkg/dataset/histogram_helper.go b/pkg/dataset/histogram_helper.go
new file mode 100644
index 0000000..8107fad
--- /dev/null
+++ b/pkg/dataset/histogram_helper.go
@@ -0,0 +1,139 @@
+/*
+Copyright 2025 The llm-d-inference-sim Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package dataset
+
+import "github.com/llm-d/llm-d-inference-sim/pkg/common"
+
+const (
+	responseLenMean   = 40
+	responseLenStddev = 20
+)
+
+// this array defines the bucket probabilities used when generating the number of tokens in a response
+var respLenBucketsProbabilities = [...]float64{0.2, 0.3, 0.2, 0.05, 0.1, 0.15}
+
+const (
+	flexBucketIndex    = 3
+	maxFixedBucketSize = 20
+)
+
+type histogramHelper struct {
+	cumulativeBucketsProbabilities []float64
+	random                         *common.Random
+}
+
+func newHistogramHelper(random *common.Random) *histogramHelper {
+	h := histogramHelper{random: random}
+
+	h.cumulativeBucketsProbabilities = make([]float64, len(respLenBucketsProbabilities))
+	sum := 0.0
+
+	for i, val := range respLenBucketsProbabilities {
+		sum += val
+		h.cumulativeBucketsProbabilities[i] = sum
+	}
+	return &h
+}
+
+// getResponseLengthByHistogram calculates the number of tokens to be returned in a response based on the max tokens value and the pre-defined buckets.
+// The response length is distributed according to the probabilities defined in respLenBucketsProbabilities.
+// The histogram contains equally sized buckets and a special last bucket, which contains only the maxTokens value.
+// The last element of respLenBucketsProbabilities defines the probability of a response with maxTokens tokens.
+// The other values define the probabilities for the equally sized buckets.
+// If maxTokens is small (smaller than the number of buckets), the response length is randomly selected from the range [1, maxTokens].
+func (hh *histogramHelper) getResponseLengthByHistogram(maxTokens int) int {
+	if maxTokens <= 1 {
+		return maxTokens
+	}
+	// maxTokens is small - no need to use the histogram of probabilities, just select a random value in the range [1, maxTokens]
+	if maxTokens <= len(hh.cumulativeBucketsProbabilities) {
+		return hh.random.RandomInt(1, maxTokens)
+	}
+
+	r := hh.random.RandomFloat(0, 1)
+
+	// if r falls in the last bucket, return maxTokens
+	if r > hh.cumulativeBucketsProbabilities[len(hh.cumulativeBucketsProbabilities)-2] {
+		return maxTokens
+	}
+
+	// determine which bucket to use: the first bucket with a cumulative probability larger than r.
+	// initialize bucketIndex with the last bucket to handle the case (which should not happen) when the probabilities sum is less than 1
+	bucketIndex := len(hh.cumulativeBucketsProbabilities) - 1
+	for i, c := range hh.cumulativeBucketsProbabilities {
+		if r <= c {
+			bucketIndex = i
+			break
+		}
+	}
+
+	// calculate the boundaries of the selected bucket
+	start, end := hh.calcBucketBoundaries(maxTokens, bucketIndex)
+
+	// pick uniformly within the bucket’s range
+	return hh.random.RandomInt(start, end)
+}
+
+// calcBucketBoundaries calculates the boundaries of the bucket with the given index.
+// The maximum size of equally sized buckets is defined by maxFixedBucketSize.
+// [maxFixedBucketSize*(number-of-buckets-1)+1] is the value of maxTokens for which
+// division into equally sized buckets gives buckets of size maxFixedBucketSize.
+// If maxTokens is [maxFixedBucketSize*(number-of-buckets-1)+1] or less,
+// all buckets will be of equal size, except the last bucket, which contains only one value.
+// If maxTokens is higher than [maxFixedBucketSize*(number-of-buckets-1)+1],
+// and flexBucketIndex is valid (between 0 and number of buckets - 1), the bucket sizes will not be equal.
+// In this case, all buckets except the one at flexBucketIndex will have size maxFixedBucketSize (and the last will have size 1),
+// and the bucket at flexBucketIndex will 'stretch' to cover the remaining range.
+func (hh *histogramHelper) calcBucketBoundaries(maxTokens int, bucketIndex int) (start int, end int) {
+	maxEquallyBucketsSz := maxFixedBucketSize*(len(hh.cumulativeBucketsProbabilities)-1) + 1
+
+	if maxTokens <= maxEquallyBucketsSz || flexBucketIndex < 0 || flexBucketIndex >= len(hh.cumulativeBucketsProbabilities)-1 {
+		// create equally sized buckets
+		// calculate the size of all of the buckets (except the special last bucket)
+		bucketSize := float64(maxTokens-1) / float64(len(hh.cumulativeBucketsProbabilities)-1)
+		start = int(bucketSize*float64(bucketIndex)) + 1
+		end = int(bucketSize * float64(bucketIndex+1))
+	} else {
+		// create non-equally sized buckets and find the boundaries of the required bucket
+		if bucketIndex < flexBucketIndex {
+			// the relevant bucket is before the flex bucket, all buckets are of the same size (maxFixedBucketSize)
+			// start is the minimum number in the required bucket
+			start = maxFixedBucketSize*bucketIndex + 1
+			end = maxFixedBucketSize * (bucketIndex + 1)
+		} else {
+			flexBucketSize := maxTokens - (maxFixedBucketSize * (len(hh.cumulativeBucketsProbabilities) - 2))
+
+			if bucketIndex == flexBucketIndex {
+				// the relevant bucket is the flex bucket
+				start = int(maxFixedBucketSize*float64(bucketIndex)) + 1
+				end = maxFixedBucketSize*bucketIndex + flexBucketSize
+			} else {
+				// the relevant bucket is one of the buckets after the flex bucket
+				start = int(maxFixedBucketSize*float64(bucketIndex-1)) + flexBucketSize + 1
+				end = maxFixedBucketSize*bucketIndex + flexBucketSize
+			}
+		}
+	}
+
+	// because of rounding, end could sometimes be maxTokens; clamp it to maxTokens-1
+	if end >= maxTokens {
+		end = maxTokens - 1
+	}
+
+	return start, end
+}
diff --git a/pkg/dataset/sqlite_helper.go b/pkg/dataset/sqlite_helper.go
new file mode 100644
index 0000000..ca45dcd
--- /dev/null
+++ b/pkg/dataset/sqlite_helper.go
@@ -0,0 +1,300 @@
+/*
+Copyright 2025 The llm-d-inference-sim Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/ + +package dataset + +import ( + "database/sql" + "encoding/json" + "errors" + "fmt" + "os" + "strconv" + "time" + + "github.com/go-logr/logr" + "github.com/llm-d/llm-d-inference-sim/pkg/common/logging" +) + +// use constants for expected column names and types +const ( + tableName = "llmd" + idCol = "id" + promptHashCol = "prompt_hash" + genTokensCol = "gen_tokens" + nGenTokensCol = "n_gen_tokens" + idColType = "INTEGER" + promptHashColType = "BLOB" + genTokensColType = "JSON" + nGenTokensColType = "INTEGER" +) + +type sqliteHelper struct { + logger logr.Logger + db *sql.DB +} + +func newSqliteHelper(logger logr.Logger) *sqliteHelper { + return &sqliteHelper{logger: logger} +} + +func (s *sqliteHelper) connectToDB(path string, useInMemory bool) error { + if s.db != nil { + err := s.db.Close() + if err != nil { + s.logger.Error(err, "failed to close existing database connection") + } + s.db = nil + } + // check if file exists + _, err := os.Stat(path) + if err != nil { + return fmt.Errorf("database file does not exist: %w", err) + } + + if useInMemory { + err = s.loadDatabaseInMemory(path) + if err != nil { + return err + } + } else { + // Use file-based database (original behavior) + s.db, err = sql.Open("sqlite3", "file:"+path+"?mode=ro") + if err != nil { + return fmt.Errorf("failed to open database: %w", err) + } + + // Check if there are other connections to the database + _, err = s.db.Exec("BEGIN EXCLUSIVE;") + if err != nil { + if closeErr := s.db.Close(); closeErr != nil { + s.logger.Error(closeErr, "failed to close database after failing to acquire exclusive lock") + } + s.db = nil + return fmt.Errorf("database is locked or has other active connections: %w", err) + } + } + + err = s.verifyDB() + if err != nil { + return fmt.Errorf("failed to verify database: %w", err) + } + + count, err := s.getRecordsCount() + if err != nil { + s.logger.Error(err, "failed to get records count") + return fmt.Errorf("failed to query database: %w", err) + } + + if useInMemory { + s.logger.V(logging.INFO).Info("In-memory database connected successfully", "path", path, "records count", count) + } else { + s.logger.V(logging.INFO).Info("Database connected successfully", "path", path, "records count", count) + } + return nil +} + +func (s *sqliteHelper) loadDatabaseInMemory(path string) error { + s.logger.V(logging.INFO).Info("Loading database into memory...") + start := time.Now() + + // Create in-memory database + var err error + s.db, err = sql.Open("sqlite3", ":memory:") + if err != nil { + return fmt.Errorf("failed to create in-memory database: %w", err) + } + + // Use ATTACH to copy the database + attachSQL := fmt.Sprintf("ATTACH DATABASE '%s' AS source", path) + _, err = s.db.Exec(attachSQL) + if err != nil { + if closeErr := s.db.Close(); closeErr != nil { + s.logger.Error(closeErr, "failed to close in-memory database after attach failure") + } + s.db = nil + return fmt.Errorf("failed to attach source database: %w", err) + } + + // Copy the table structure first + createTableStmt := fmt.Sprintf(`CREATE TABLE %s ( + id INTEGER PRIMARY KEY, + prompt_hash BLOB, + gen_tokens JSON, + n_gen_tokens INTEGER + )`, tableName) + _, err = s.db.Exec(createTableStmt) + if err != nil { + if closeErr := s.db.Close(); closeErr != nil { + s.logger.Error(closeErr, "failed to close in-memory database after create table failure") + } + s.db = nil + return fmt.Errorf("failed to create table: %w", err) + } + + // Copy the data + _, err = s.db.Exec("INSERT INTO " + tableName + " SELECT * FROM source." 
+ tableName)
+	if err != nil {
+		if closeErr := s.db.Close(); closeErr != nil {
+			s.logger.Error(closeErr, "failed to close in-memory database after copy failure")
+		}
+		s.db = nil
+		return fmt.Errorf("failed to copy data: %w", err)
+	}
+
+	// Detach the source database
+	_, err = s.db.Exec("DETACH DATABASE source")
+	if err != nil {
+		s.logger.Error(err, "failed to detach source database")
+	}
+
+	loadTime := time.Since(start)
+	s.logger.V(logging.INFO).Info("Database loaded into memory", "load_time", loadTime.String())
+	return nil
+}
+
+func (s *sqliteHelper) verifyDB() error {
+	rows, err := s.db.Query("PRAGMA table_info(" + tableName + ");")
+	if err != nil {
+		return fmt.Errorf("failed to query table info for `%s`: %w", tableName, err)
+	}
+	defer func() {
+		if cerr := rows.Close(); cerr != nil {
+			s.logger.Error(cerr, "failed to close rows after querying table info")
+		}
+	}()
+
+	expectedColumns := map[string]string{
+		idCol:         idColType,
+		promptHashCol: promptHashColType,
+		genTokensCol:  genTokensColType,
+		nGenTokensCol: nGenTokensColType,
+	}
+
+	columnsFound := make(map[string]bool)
+
+	var (
+		columnName string
+		columnType string
+		cid        int
+		notnull    int
+		dfltValue  interface{}
+		pk         int
+	)
+
+	for rows.Next() {
+		err := rows.Scan(&cid, &columnName, &columnType, &notnull, &dfltValue, &pk)
+		if err != nil {
+			return fmt.Errorf("failed to scan table info row: %w", err)
+		}
+		if expectedType, exists := expectedColumns[columnName]; exists {
+			if columnType != expectedType {
+				return fmt.Errorf("column %s has incorrect type: expected %s, got %s", columnName, expectedType, columnType)
+			}
+			columnsFound[columnName] = true
+		}
+	}
+
+	for col := range expectedColumns {
+		if !columnsFound[col] {
+			return fmt.Errorf("missing expected column in %s table: %s", tableName, col)
+		}
+	}
+
+	return nil
+}
+
+func (s *sqliteHelper) getRecordsCount() (int, error) {
+	var count int
+	err := s.db.QueryRow("SELECT COUNT(" + promptHashCol + ") FROM " + tableName + ";").Scan(&count)
+	if err != nil {
+		return 0, fmt.Errorf("failed to query database: %w", err)
+	}
+	return count, nil
+}
+
+// query runs a SQL query that retrieves response tokens as arrays of strings.
+// It can return multiple responses.
+// The results are named so that the deferred rows.Close error is propagated to the caller.
+func (s *sqliteHelper) query(query string) (responses [][]string, err error) {
+	rows, err := s.db.Query(query)
+	if err != nil {
+		s.logger.Error(err, "failed to query database. Ensure dataset file is still valid. Will generate random tokens instead.")
+		return nil, err
+	}
+	defer func() {
+		if cerr := rows.Close(); cerr != nil {
+			s.logger.Error(cerr, "failed to close rows after query")
+			// propagate the close error through the named return value
+			if err == nil {
+				err = cerr
+			} else {
+				err = errors.Join(err, cerr)
+			}
+		}
+	}()
+	return unmarshalAllRecords(rows)
+}
+
+func unmarshalAllRecords(rows *sql.Rows) ([][]string, error) {
+	var responses [][]string
+
+	for rows.Next() {
+		var responseJSON string
+		if err := rows.Scan(&responseJSON); err != nil {
+			return nil, fmt.Errorf("failed to scan row: %w", err)
+		}
+
+		var tokens []string
+		if err := json.Unmarshal([]byte(responseJSON), &tokens); err != nil {
+			return nil, fmt.Errorf("failed to unmarshal tokens JSON: %w", err)
+		}
+		responses = append(responses, tokens)
+	}
+	return responses, nil
+}
+
+func (s *sqliteHelper) buildQuery(where string, isRand bool, isLimitOne bool) string {
+	query := "SELECT " + genTokensCol + " FROM " + tableName
+
+	if where != "" {
+		query += " WHERE " + where
+	}
+
+	if isRand {
+		query += " ORDER BY RANDOM()"
+	}
+
+	if isLimitOne {
+		query += " LIMIT 1"
+	}
+	query += ";"
+
+	return query
+}
+
+func (s *sqliteHelper) getResponsesForPrompt(promptHashHex string) ([][]string, error) {
+	query := s.buildQuery(promptHashCol+"=X'"+promptHashHex+"'", false, false)
+	return s.query(query)
+}
+
+func (s *sqliteHelper) getResponsesForLen(maxLen int) ([][]string, error) {
+	query := s.buildQuery(nGenTokensCol+"<="+strconv.Itoa(maxLen), true, true)
+	return s.query(query)
+}
+
+func (s *sqliteHelper) getRandomResponse() ([][]string, error) {
+	query := s.buildQuery("", true, true)
+	return s.query(query)
+}
diff --git a/pkg/dataset/test_helpers.go b/pkg/dataset/test_helpers.go
index ed6b6f1..7cde861 100644
--- a/pkg/dataset/test_helpers.go
+++ b/pkg/dataset/test_helpers.go
@@ -27,7 +27,7 @@ func IsValidText(text string) bool {
 		textToCheck := text[charsTested:]
 
 		found := false
-		for _, fakeSentence := range chatCompletionFakeResponses {
+		for _, fakeSentence := range completionFakeResponses {
 			if len(textToCheck) <= len(fakeSentence) {
 				if strings.HasPrefix(fakeSentence, textToCheck) {
 					found = true
diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go
index cb1ab54..c0ca219 100644
--- a/pkg/llm-d-inference-sim/simulator.go
+++ b/pkg/llm-d-inference-sim/simulator.go
@@ -247,6 +247,14 @@ func (s *VllmSimulator) Start(ctx context.Context) error {
 		return err
 	}
 
+	if s.config.DatasetURL != "" {
+		// if a remote responses dataset should be used, download it first (this can take time)
+		downloader := dataset.NewDsDownloader(s.logger)
+		if err := downloader.DownloadDataset(ctx, s.config.DatasetURL, s.config.DatasetPath); err != nil {
+			return err
+		}
+	}
+
 	// For Data Parallel, start data-parallel-size - 1 additional simulators
 	g, ctx := errgroup.WithContext(ctx)
 	if s.config.DPSize > 1 {
@@ -371,32 +379,27 @@ func (s *VllmSimulator) initializeSim(ctx context.Context) error {
 }
 
 func (s *VllmSimulator) initDataset(ctx context.Context) error {
-	randDataset := &dataset.BaseDataset{}
-	err := randDataset.Init(ctx, s.logger, "", "", false)
-	if err != nil {
-		return fmt.Errorf("failed to initialize random dataset: %w", err)
-	}
-
 	if s.config.DatasetPath == "" && s.config.DatasetURL == "" {
+		// use predefined sentences as responses
+		randDataset := &dataset.BaseDataset{}
+		err := randDataset.Init(ctx, s.logger, s.random, s.config.MaxModelLen)
+		if err != nil {
+			return fmt.Errorf("failed to initialize random dataset: %w", err)
+		}
 		s.logger.V(logging.INFO).Info("No 
dataset path or URL provided, using random text for responses") s.dataset = randDataset return nil } + // use dataset containing responses custDataset := &dataset.CustomDataset{} - err = custDataset.Init(ctx, s.logger, s.config.DatasetPath, s.config.DatasetURL, s.config.DatasetInMemory) + err := custDataset.Init(ctx, s.logger, s.random, s.config.DatasetPath, s.config.DatasetInMemory, s.config.MaxModelLen) if err == nil { s.dataset = custDataset return nil } - if strings.HasPrefix(err.Error(), "database is locked") { - s.logger.V(logging.WARN).Info("Database is locked by another process, will use preset text for responses instead") - s.dataset = randDataset - return nil - } - return err } diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index 38e0ad9..a08d946 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -180,16 +180,16 @@ var _ = Describe("Simulator", func() { msg := resp.Choices[0].Message.Content Expect(msg).ShouldNot(BeEmpty()) - if numTokens > 0 { - tokens := common.Tokenize(msg) - Expect(int64(len(tokens))).Should(BeNumerically("<=", numTokens)) + if mode == common.ModeEcho { + // in case of echo mode check that the text is returned as-is + Expect(msg).Should(Equal(testUserMessage)) } else { - if mode == common.ModeRandom { + if numTokens > 0 { + tokens := common.Tokenize(msg) + Expect(int64(len(tokens))).Should(BeNumerically("<=", numTokens)) + } else { // in case of random mode ensure that the returned message could be output of the random text generator Expect(dataset.IsValidText(msg)).To(BeTrue()) - } else { - // in case of echo mode check that the text is returned as-is - Expect(msg).Should(Equal(testUserMessage)) } } }, @@ -251,16 +251,16 @@ var _ = Describe("Simulator", func() { text := resp.Choices[0].Text Expect(text).ShouldNot(BeEmpty()) - if numTokens != 0 { - tokens := common.Tokenize(text) - Expect(int64(len(tokens))).Should(BeNumerically("<=", numTokens)) + if mode == common.ModeEcho { + // in case of echo mode check that the text is returned as-is + Expect(text).Should(Equal(testUserMessage)) } else { - if mode == common.ModeRandom { + if numTokens != 0 { + tokens := common.Tokenize(text) + Expect(int64(len(tokens))).Should(BeNumerically("<=", numTokens)) + } else { // in case of random mode ensure that the returned message could be output of the random text generator Expect(dataset.IsValidText(text)).To(BeTrue()) - } else { - // in case of echo mode check that the text is returned as-is - Expect(text).Should(Equal(testUserMessage)) } } }, diff --git a/pkg/llm-d-inference-sim/streaming.go b/pkg/llm-d-inference-sim/streaming.go index 918408a..7c96fa9 100644 --- a/pkg/llm-d-inference-sim/streaming.go +++ b/pkg/llm-d-inference-sim/streaming.go @@ -26,7 +26,6 @@ import ( "github.com/llm-d/llm-d-inference-sim/pkg/common" "github.com/llm-d/llm-d-inference-sim/pkg/common/logging" - "github.com/llm-d/llm-d-inference-sim/pkg/dataset" openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" "github.com/valyala/fasthttp" ) @@ -142,7 +141,7 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ var chunk openaiserverapi.CompletionRespChunk var finishReasonToSend *string - if i == len(genTokens)-1 && (finishReason == dataset.LengthFinishReason || finishReason == dataset.ToolsFinishReason) { + if i == len(genTokens)-1 && (finishReason == common.LengthFinishReason || finishReason == common.ToolsFinishReason) { finishReasonToSend 
= &finishReason } if context.isChatCompletion { @@ -161,7 +160,7 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ // send the last chunk if finish reason is stop var chunk openaiserverapi.CompletionRespChunk - if finishReason == dataset.StopFinishReason { + if finishReason == common.StopFinishReason { if context.isChatCompletion { chunk = s.createChatCompletionChunk(context, "", nil, "", &finishReason) } else { diff --git a/pkg/llm-d-inference-sim/tools_test.go b/pkg/llm-d-inference-sim/tools_test.go index aa6f54c..9edab45 100644 --- a/pkg/llm-d-inference-sim/tools_test.go +++ b/pkg/llm-d-inference-sim/tools_test.go @@ -24,7 +24,6 @@ import ( "strings" "github.com/llm-d/llm-d-inference-sim/pkg/common" - "github.com/llm-d/llm-d-inference-sim/pkg/dataset" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/openai/openai-go/v3" @@ -397,7 +396,7 @@ var _ = Describe("Simulator for request with tools", func() { for _, choice := range chunk.Choices { if choice.Delta.Role != "" { role = choice.Delta.Role - } else if choice.FinishReason == "" || choice.FinishReason == dataset.ToolsFinishReason { + } else if choice.FinishReason == "" || choice.FinishReason == common.ToolsFinishReason { toolCalls := choice.Delta.ToolCalls Expect(toolCalls).To(HaveLen(1)) tc := toolCalls[0] diff --git a/pkg/llm-d-inference-sim/worker.go b/pkg/llm-d-inference-sim/worker.go index a256d19..041b3b4 100644 --- a/pkg/llm-d-inference-sim/worker.go +++ b/pkg/llm-d-inference-sim/worker.go @@ -25,7 +25,6 @@ import ( "github.com/go-logr/logr" "github.com/llm-d/llm-d-inference-sim/pkg/common" "github.com/llm-d/llm-d-inference-sim/pkg/common/logging" - "github.com/llm-d/llm-d-inference-sim/pkg/dataset" openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" "github.com/valyala/fasthttp" ) @@ -109,12 +108,12 @@ func (s *VllmSimulator) processRequestAsync(reqCtx *openaiserverapi.CompletionRe req.GetTools() != nil { toolCalls, completionTokens, err = common.CreateToolCalls(req.GetTools(), req.GetToolChoice(), s.config, s.random) - finishReason = dataset.ToolsFinishReason + finishReason = common.ToolsFinishReason } if toolCalls == nil && err == nil { // Either no tool calls were defined, or we randomly chose not to create tool calls, // so we generate a response text. - responseTokens, finishReason, err = s.dataset.GetTokens(req, s.config.Mode, s.random) + responseTokens, finishReason, err = s.dataset.GetTokens(req, s.config.Mode) completionTokens += len(responseTokens) } if err != nil { @@ -154,7 +153,7 @@ func (s *VllmSimulator) processRequestAsync(reqCtx *openaiserverapi.CompletionRe } else { if req.IsDoRemoteDecode() { // in case this is prefill pod processing, return special finish reason - finishReason = dataset.RemoteDecodeFinishReason + finishReason = common.RemoteDecodeFinishReason } s.sendResponse(reqCtx, responseTokens, toolCalls, displayModel, finishReason, &usageData) wg.Done()